MLIR  21.0.0git
LLVMTypes.cpp
Go to the documentation of this file.
1 //===- LLVMTypes.cpp - MLIR LLVM dialect types ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the types for the LLVM dialect in MLIR. These MLIR types
10 // correspond to the LLVM IR type system.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "TypeDetail.h"
15 
18 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/TypeSupport.h"
21 
22 #include "llvm/ADT/ScopeExit.h"
23 #include "llvm/ADT/TypeSwitch.h"
24 #include "llvm/Support/TypeSize.h"
25 #include <optional>
26 
27 using namespace mlir;
28 using namespace mlir::LLVM;
29 
/// Number of bits in one byte; used to convert between byte and bit
/// quantities in the data layout computations of this file.
static constexpr uint64_t kBitsInByte = 8;
31 
32 //===----------------------------------------------------------------------===//
33 // custom<FunctionTypes>
34 //===----------------------------------------------------------------------===//
35 
36 static ParseResult parseFunctionTypes(AsmParser &p, SmallVector<Type> &params,
37  bool &isVarArg) {
38  isVarArg = false;
39  // `(` `)`
40  if (succeeded(p.parseOptionalRParen()))
41  return success();
42 
43  // `(` `...` `)`
44  if (succeeded(p.parseOptionalEllipsis())) {
45  isVarArg = true;
46  return p.parseRParen();
47  }
48 
49  // type (`,` type)* (`,` `...`)?
50  Type type;
51  if (parsePrettyLLVMType(p, type))
52  return failure();
53  params.push_back(type);
54  while (succeeded(p.parseOptionalComma())) {
55  if (succeeded(p.parseOptionalEllipsis())) {
56  isVarArg = true;
57  return p.parseRParen();
58  }
59  if (parsePrettyLLVMType(p, type))
60  return failure();
61  params.push_back(type);
62  }
63  return p.parseRParen();
64 }
65 
67  bool isVarArg) {
68  llvm::interleaveComma(params, p,
69  [&](Type type) { printPrettyLLVMType(p, type); });
70  if (isVarArg) {
71  if (!params.empty())
72  p << ", ";
73  p << "...";
74  }
75  p << ')';
76 }
77 
78 //===----------------------------------------------------------------------===//
79 // custom<ExtTypeParams>
80 //===----------------------------------------------------------------------===//
81 
82 /// Parses the parameter list for a target extension type. The parameter list
83 /// contains an optional list of type parameters, followed by an optional list
84 /// of integer parameters. Type and integer parameters cannot be interleaved in
85 /// the list.
86 /// extTypeParams ::= typeList? | intList? | (typeList "," intList)
87 /// typeList ::= type ("," type)*
88 /// intList ::= integer ("," integer)*
89 static ParseResult
91  SmallVectorImpl<unsigned int> &intParams) {
92  bool parseType = true;
93  auto typeOrIntParser = [&]() -> ParseResult {
94  unsigned int i;
95  auto intResult = p.parseOptionalInteger(i);
96  if (intResult.has_value() && !failed(*intResult)) {
97  // Successfully parsed an integer.
98  intParams.push_back(i);
99  // After the first integer was successfully parsed, no
100  // more types can be parsed.
101  parseType = false;
102  return success();
103  }
104  if (parseType) {
105  Type t;
106  if (!parsePrettyLLVMType(p, t)) {
107  // Successfully parsed a type.
108  typeParams.push_back(t);
109  return success();
110  }
111  }
112  return failure();
113  };
114  if (p.parseCommaSeparatedList(typeOrIntParser)) {
116  "failed to parse parameter list for target extension type");
117  return failure();
118  }
119  return success();
120 }
121 
122 static void printExtTypeParams(AsmPrinter &p, ArrayRef<Type> typeParams,
123  ArrayRef<unsigned int> intParams) {
124  p << typeParams;
125  if (!typeParams.empty() && !intParams.empty())
126  p << ", ";
127 
128  p << intParams;
129 }
130 
131 //===----------------------------------------------------------------------===//
132 // ODS-Generated Definitions
133 //===----------------------------------------------------------------------===//
134 
135 /// These are unused for now.
136 /// TODO: Move over to these once more types have been migrated to TypeDef.
137 LLVM_ATTRIBUTE_UNUSED static OptionalParseResult
138 generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value);
139 LLVM_ATTRIBUTE_UNUSED static LogicalResult
141 
142 #include "mlir/Dialect/LLVMIR/LLVMTypeInterfaces.cpp.inc"
143 
144 #define GET_TYPEDEF_CLASSES
145 #include "mlir/Dialect/LLVMIR/LLVMTypes.cpp.inc"
146 
147 //===----------------------------------------------------------------------===//
148 // LLVMArrayType
149 //===----------------------------------------------------------------------===//
150 
151 bool LLVMArrayType::isValidElementType(Type type) {
152  return !llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
153  LLVMFunctionType, LLVMTokenType>(type);
154 }
155 
156 LLVMArrayType LLVMArrayType::get(Type elementType, uint64_t numElements) {
157  assert(elementType && "expected non-null subtype");
158  return Base::get(elementType.getContext(), elementType, numElements);
159 }
160 
161 LLVMArrayType
162 LLVMArrayType::getChecked(function_ref<InFlightDiagnostic()> emitError,
163  Type elementType, uint64_t numElements) {
164  assert(elementType && "expected non-null subtype");
165  return Base::getChecked(emitError, elementType.getContext(), elementType,
166  numElements);
167 }
168 
169 LogicalResult
171  Type elementType, uint64_t numElements) {
172  if (!isValidElementType(elementType))
173  return emitError() << "invalid array element type: " << elementType;
174  return success();
175 }
176 
177 //===----------------------------------------------------------------------===//
178 // DataLayoutTypeInterface
179 //===----------------------------------------------------------------------===//
180 
181 llvm::TypeSize
182 LLVMArrayType::getTypeSizeInBits(const DataLayout &dataLayout,
183  DataLayoutEntryListRef params) const {
184  return llvm::TypeSize::getFixed(kBitsInByte *
185  getTypeSize(dataLayout, params));
186 }
187 
188 llvm::TypeSize LLVMArrayType::getTypeSize(const DataLayout &dataLayout,
189  DataLayoutEntryListRef params) const {
190  return llvm::alignTo(dataLayout.getTypeSize(getElementType()),
191  dataLayout.getTypeABIAlignment(getElementType())) *
192  getNumElements();
193 }
194 
195 uint64_t LLVMArrayType::getABIAlignment(const DataLayout &dataLayout,
196  DataLayoutEntryListRef params) const {
197  return dataLayout.getTypeABIAlignment(getElementType());
198 }
199 
200 uint64_t
201 LLVMArrayType::getPreferredAlignment(const DataLayout &dataLayout,
202  DataLayoutEntryListRef params) const {
203  return dataLayout.getTypePreferredAlignment(getElementType());
204 }
205 
206 //===----------------------------------------------------------------------===//
207 // Function type.
208 //===----------------------------------------------------------------------===//
209 
210 bool LLVMFunctionType::isValidArgumentType(Type type) {
211  return !llvm::isa<LLVMVoidType, LLVMFunctionType>(type);
212 }
213 
214 bool LLVMFunctionType::isValidResultType(Type type) {
215  return !llvm::isa<LLVMFunctionType, LLVMMetadataType, LLVMLabelType>(type);
216 }
217 
218 LLVMFunctionType LLVMFunctionType::get(Type result, ArrayRef<Type> arguments,
219  bool isVarArg) {
220  assert(result && "expected non-null result");
221  return Base::get(result.getContext(), result, arguments, isVarArg);
222 }
223 
224 LLVMFunctionType
225 LLVMFunctionType::getChecked(function_ref<InFlightDiagnostic()> emitError,
226  Type result, ArrayRef<Type> arguments,
227  bool isVarArg) {
228  assert(result && "expected non-null result");
229  return Base::getChecked(emitError, result.getContext(), result, arguments,
230  isVarArg);
231 }
232 
233 LLVMFunctionType LLVMFunctionType::clone(TypeRange inputs,
234  TypeRange results) const {
235  if (results.size() != 1 || !isValidResultType(results[0]))
236  return {};
237  if (!llvm::all_of(inputs, isValidArgumentType))
238  return {};
239  return get(results[0], llvm::to_vector(inputs), isVarArg());
240 }
241 
243  return static_cast<detail::LLVMFunctionTypeStorage *>(getImpl())->returnType;
244 }
245 
246 LogicalResult
248  Type result, ArrayRef<Type> arguments, bool) {
249  if (!isValidResultType(result))
250  return emitError() << "invalid function result type: " << result;
251 
252  for (Type arg : arguments)
253  if (!isValidArgumentType(arg))
254  return emitError() << "invalid function argument type: " << arg;
255 
256  return success();
257 }
258 
259 //===----------------------------------------------------------------------===//
260 // DataLayoutTypeInterface
261 //===----------------------------------------------------------------------===//
262 
/// Pointer size, in bits, assumed for the default address space when the data
/// layout provides no entry.
static constexpr uint64_t kDefaultPointerSizeBits = 64;
/// Pointer alignment assumed for the default address space when the data
/// layout provides no entry.
static constexpr uint64_t kDefaultPointerAlignment = 8;
265 
266 std::optional<uint64_t> mlir::LLVM::extractPointerSpecValue(Attribute attr,
267  PtrDLEntryPos pos) {
268  auto spec = cast<DenseIntElementsAttr>(attr);
269  auto idx = static_cast<int64_t>(pos);
270  if (idx >= spec.size())
271  return std::nullopt;
272  return spec.getValues<uint64_t>()[idx];
273 }
274 
275 /// Returns the part of the data layout entry that corresponds to `pos` for the
276 /// given `type` by interpreting the list of entries `params`. For the pointer
277 /// type in the default address space, returns the default value if the entries
278 /// do not provide a custom one, for other address spaces returns std::nullopt.
279 static std::optional<uint64_t>
281  PtrDLEntryPos pos) {
282  // First, look for the entry for the pointer in the current address space.
283  Attribute currentEntry;
284  for (DataLayoutEntryInterface entry : params) {
285  if (!entry.isTypeEntry())
286  continue;
287  if (cast<LLVMPointerType>(cast<Type>(entry.getKey())).getAddressSpace() ==
288  type.getAddressSpace()) {
289  currentEntry = entry.getValue();
290  break;
291  }
292  }
293  if (currentEntry) {
294  std::optional<uint64_t> value = extractPointerSpecValue(currentEntry, pos);
295  // If the optional `PtrDLEntryPos::Index` entry is not available, use the
296  // pointer size as the index bitwidth.
297  if (!value && pos == PtrDLEntryPos::Index)
298  value = extractPointerSpecValue(currentEntry, PtrDLEntryPos::Size);
299  bool isSizeOrIndex =
300  pos == PtrDLEntryPos::Size || pos == PtrDLEntryPos::Index;
301  return *value / (isSizeOrIndex ? 1 : kBitsInByte);
302  }
303 
304  // If not found, and this is the pointer to the default memory space, assume
305  // 64-bit pointers.
306  if (type.getAddressSpace() == 0) {
307  bool isSizeOrIndex =
308  pos == PtrDLEntryPos::Size || pos == PtrDLEntryPos::Index;
309  return isSizeOrIndex ? kDefaultPointerSizeBits : kDefaultPointerAlignment;
310  }
311 
312  return std::nullopt;
313 }
314 
315 llvm::TypeSize
316 LLVMPointerType::getTypeSizeInBits(const DataLayout &dataLayout,
317  DataLayoutEntryListRef params) const {
318  if (std::optional<uint64_t> size =
320  return llvm::TypeSize::getFixed(*size);
321 
322  // For other memory spaces, use the size of the pointer to the default memory
323  // space.
324  return dataLayout.getTypeSizeInBits(get(getContext()));
325 }
326 
327 uint64_t LLVMPointerType::getABIAlignment(const DataLayout &dataLayout,
328  DataLayoutEntryListRef params) const {
329  if (std::optional<uint64_t> alignment =
331  return *alignment;
332 
333  return dataLayout.getTypeABIAlignment(get(getContext()));
334 }
335 
336 uint64_t
337 LLVMPointerType::getPreferredAlignment(const DataLayout &dataLayout,
338  DataLayoutEntryListRef params) const {
339  if (std::optional<uint64_t> alignment =
341  return *alignment;
342 
343  return dataLayout.getTypePreferredAlignment(get(getContext()));
344 }
345 
346 std::optional<uint64_t>
348  DataLayoutEntryListRef params) const {
349  if (std::optional<uint64_t> indexBitwidth =
351  return *indexBitwidth;
352 
353  return dataLayout.getTypeIndexBitwidth(get(getContext()));
354 }
355 
356 bool LLVMPointerType::areCompatible(
357  DataLayoutEntryListRef oldLayout, DataLayoutEntryListRef newLayout,
358  DataLayoutSpecInterface newSpec,
359  const DataLayoutIdentifiedEntryMap &map) const {
360  for (DataLayoutEntryInterface newEntry : newLayout) {
361  if (!newEntry.isTypeEntry())
362  continue;
363  uint64_t size = kDefaultPointerSizeBits;
364  uint64_t abi = kDefaultPointerAlignment;
365  auto newType =
366  llvm::cast<LLVMPointerType>(llvm::cast<Type>(newEntry.getKey()));
367  const auto *it =
368  llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
369  if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
370  return llvm::cast<LLVMPointerType>(type).getAddressSpace() ==
371  newType.getAddressSpace();
372  }
373  return false;
374  });
375  if (it == oldLayout.end()) {
376  llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
377  if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
378  return llvm::cast<LLVMPointerType>(type).getAddressSpace() == 0;
379  }
380  return false;
381  });
382  }
383  if (it != oldLayout.end()) {
386  }
387 
388  Attribute newSpec = llvm::cast<DenseIntElementsAttr>(newEntry.getValue());
389  uint64_t newSize = *extractPointerSpecValue(newSpec, PtrDLEntryPos::Size);
390  uint64_t newAbi = *extractPointerSpecValue(newSpec, PtrDLEntryPos::Abi);
391  if (size != newSize || abi < newAbi || abi % newAbi != 0)
392  return false;
393  }
394  return true;
395 }
396 
398  Location loc) const {
399  for (DataLayoutEntryInterface entry : entries) {
400  if (!entry.isTypeEntry())
401  continue;
402  auto key = llvm::cast<Type>(entry.getKey());
403  auto values = llvm::dyn_cast<DenseIntElementsAttr>(entry.getValue());
404  if (!values || (values.size() != 3 && values.size() != 4)) {
405  return emitError(loc)
406  << "expected layout attribute for " << key
407  << " to be a dense integer elements attribute with 3 or 4 "
408  "elements";
409  }
410  if (!values.getElementType().isInteger(64))
411  return emitError(loc) << "expected i64 parameters for " << key;
412 
415  return emitError(loc) << "preferred alignment is expected to be at least "
416  "as large as ABI alignment";
417  }
418  }
419  return success();
420 }
421 
422 //===----------------------------------------------------------------------===//
423 // Struct type.
424 //===----------------------------------------------------------------------===//
425 
426 bool LLVMStructType::isValidElementType(Type type) {
427  return !llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
428  LLVMFunctionType, LLVMTokenType>(type);
429 }
430 
431 LLVMStructType LLVMStructType::getIdentified(MLIRContext *context,
432  StringRef name) {
433  return Base::get(context, name, /*opaque=*/false);
434 }
435 
436 LLVMStructType LLVMStructType::getIdentifiedChecked(
438  StringRef name) {
439  return Base::getChecked(emitError, context, name, /*opaque=*/false);
440 }
441 
442 LLVMStructType LLVMStructType::getNewIdentified(MLIRContext *context,
443  StringRef name,
444  ArrayRef<Type> elements,
445  bool isPacked) {
446  std::string stringName = name.str();
447  unsigned counter = 0;
448  do {
449  auto type = LLVMStructType::getIdentified(context, stringName);
450  if (type.isInitialized() || failed(type.setBody(elements, isPacked))) {
451  counter += 1;
452  stringName = (Twine(name) + "." + std::to_string(counter)).str();
453  continue;
454  }
455  return type;
456  } while (true);
457 }
458 
459 LLVMStructType LLVMStructType::getLiteral(MLIRContext *context,
460  ArrayRef<Type> types, bool isPacked) {
461  return Base::get(context, types, isPacked);
462 }
463 
464 LLVMStructType
465 LLVMStructType::getLiteralChecked(function_ref<InFlightDiagnostic()> emitError,
466  MLIRContext *context, ArrayRef<Type> types,
467  bool isPacked) {
468  return Base::getChecked(emitError, context, types, isPacked);
469 }
470 
471 LLVMStructType LLVMStructType::getOpaque(StringRef name, MLIRContext *context) {
472  return Base::get(context, name, /*opaque=*/true);
473 }
474 
475 LLVMStructType
476 LLVMStructType::getOpaqueChecked(function_ref<InFlightDiagnostic()> emitError,
477  MLIRContext *context, StringRef name) {
478  return Base::getChecked(emitError, context, name, /*opaque=*/true);
479 }
480 
481 LogicalResult LLVMStructType::setBody(ArrayRef<Type> types, bool isPacked) {
482  assert(isIdentified() && "can only set bodies of identified structs");
483  assert(llvm::all_of(types, LLVMStructType::isValidElementType) &&
484  "expected valid body types");
485  return Base::mutate(types, isPacked);
486 }
487 
488 bool LLVMStructType::isPacked() const { return getImpl()->isPacked(); }
489 bool LLVMStructType::isIdentified() const { return getImpl()->isIdentified(); }
490 bool LLVMStructType::isOpaque() const {
491  return getImpl()->isIdentified() &&
492  (getImpl()->isOpaque() || !getImpl()->isInitialized());
493 }
494 bool LLVMStructType::isInitialized() { return getImpl()->isInitialized(); }
495 StringRef LLVMStructType::getName() const { return getImpl()->getIdentifier(); }
496 ArrayRef<Type> LLVMStructType::getBody() const {
497  return isIdentified() ? getImpl()->getIdentifiedStructBody()
498  : getImpl()->getTypeList();
499 }
500 
501 LogicalResult
502 LLVMStructType::verifyInvariants(function_ref<InFlightDiagnostic()>, StringRef,
503  bool) {
504  return success();
505 }
506 
507 LogicalResult
508 LLVMStructType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
509  ArrayRef<Type> types, bool) {
510  for (Type t : types)
511  if (!isValidElementType(t))
512  return emitError() << "invalid LLVM structure element type: " << t;
513 
514  return success();
515 }
516 
517 llvm::TypeSize
518 LLVMStructType::getTypeSizeInBits(const DataLayout &dataLayout,
519  DataLayoutEntryListRef params) const {
520  auto structSize = llvm::TypeSize::getFixed(0);
521  uint64_t structAlignment = 1;
522  for (Type element : getBody()) {
523  uint64_t elementAlignment =
524  isPacked() ? 1 : dataLayout.getTypeABIAlignment(element);
525  // Add padding to the struct size to align it to the abi alignment of the
526  // element type before than adding the size of the element.
527  structSize = llvm::alignTo(structSize, elementAlignment);
528  structSize += dataLayout.getTypeSize(element);
529 
530  // The alignment requirement of a struct is equal to the strictest alignment
531  // requirement of its elements.
532  structAlignment = std::max(elementAlignment, structAlignment);
533  }
534  // At the end, add padding to the struct to satisfy its own alignment
535  // requirement. Otherwise structs inside of arrays would be misaligned.
536  structSize = llvm::alignTo(structSize, structAlignment);
537  return structSize * kBitsInByte;
538 }
539 
namespace {
/// Positions of the values inside the dense integer attribute that holds a
/// struct data layout entry.
enum class StructDLEntryPos {
  /// ABI alignment value.
  Abi = 0,
  /// Preferred alignment value; may be absent from the attribute.
  Preferred = 1
};
} // namespace
543 
544 static std::optional<uint64_t>
546  StructDLEntryPos pos) {
547  const auto *currentEntry =
548  llvm::find_if(params, [](DataLayoutEntryInterface entry) {
549  return entry.isTypeEntry();
550  });
551  if (currentEntry == params.end())
552  return std::nullopt;
553 
554  auto attr = llvm::cast<DenseIntElementsAttr>(currentEntry->getValue());
555  if (pos == StructDLEntryPos::Preferred &&
556  attr.size() <= static_cast<int64_t>(StructDLEntryPos::Preferred))
557  // If no preferred was specified, fall back to abi alignment
558  pos = StructDLEntryPos::Abi;
559 
560  return attr.getValues<uint64_t>()[static_cast<size_t>(pos)];
561 }
562 
563 static uint64_t calculateStructAlignment(const DataLayout &dataLayout,
564  DataLayoutEntryListRef params,
565  LLVMStructType type,
566  StructDLEntryPos pos) {
567  // Packed structs always have an abi alignment of 1
568  if (pos == StructDLEntryPos::Abi && type.isPacked()) {
569  return 1;
570  }
571 
572  // The alignment requirement of a struct is equal to the strictest alignment
573  // requirement of its elements.
574  uint64_t structAlignment = 1;
575  for (Type iter : type.getBody()) {
576  structAlignment =
577  std::max(dataLayout.getTypeABIAlignment(iter), structAlignment);
578  }
579 
580  // Entries are only allowed to be stricter than the required alignment
581  if (std::optional<uint64_t> entryResult =
582  getStructDataLayoutEntry(params, type, pos))
583  return std::max(*entryResult / kBitsInByte, structAlignment);
584 
585  return structAlignment;
586 }
587 
588 uint64_t LLVMStructType::getABIAlignment(const DataLayout &dataLayout,
589  DataLayoutEntryListRef params) const {
590  return calculateStructAlignment(dataLayout, params, *this,
591  StructDLEntryPos::Abi);
592 }
593 
594 uint64_t
595 LLVMStructType::getPreferredAlignment(const DataLayout &dataLayout,
596  DataLayoutEntryListRef params) const {
597  return calculateStructAlignment(dataLayout, params, *this,
598  StructDLEntryPos::Preferred);
599 }
600 
601 static uint64_t extractStructSpecValue(Attribute attr, StructDLEntryPos pos) {
602  return llvm::cast<DenseIntElementsAttr>(attr)
603  .getValues<uint64_t>()[static_cast<size_t>(pos)];
604 }
605 
606 bool LLVMStructType::areCompatible(
607  DataLayoutEntryListRef oldLayout, DataLayoutEntryListRef newLayout,
608  DataLayoutSpecInterface newSpec,
609  const DataLayoutIdentifiedEntryMap &map) const {
610  for (DataLayoutEntryInterface newEntry : newLayout) {
611  if (!newEntry.isTypeEntry())
612  continue;
613 
614  const auto *previousEntry =
615  llvm::find_if(oldLayout, [](DataLayoutEntryInterface entry) {
616  return entry.isTypeEntry();
617  });
618  if (previousEntry == oldLayout.end())
619  continue;
620 
621  uint64_t abi = extractStructSpecValue(previousEntry->getValue(),
622  StructDLEntryPos::Abi);
623  uint64_t newAbi =
624  extractStructSpecValue(newEntry.getValue(), StructDLEntryPos::Abi);
625  if (abi < newAbi || abi % newAbi != 0)
626  return false;
627  }
628  return true;
629 }
630 
632  Location loc) const {
633  for (DataLayoutEntryInterface entry : entries) {
634  if (!entry.isTypeEntry())
635  continue;
636 
637  auto key = llvm::cast<LLVMStructType>(llvm::cast<Type>(entry.getKey()));
638  auto values = llvm::dyn_cast<DenseIntElementsAttr>(entry.getValue());
639  if (!values || (values.size() != 2 && values.size() != 1)) {
640  return emitError(loc)
641  << "expected layout attribute for "
642  << llvm::cast<Type>(entry.getKey())
643  << " to be a dense integer elements attribute of 1 or 2 elements";
644  }
645  if (!values.getElementType().isInteger(64))
646  return emitError(loc) << "expected i64 entries for " << key;
647 
648  if (key.isIdentified() || !key.getBody().empty()) {
649  return emitError(loc) << "unexpected layout attribute for struct " << key;
650  }
651 
652  if (values.size() == 1)
653  continue;
654 
655  if (extractStructSpecValue(values, StructDLEntryPos::Abi) >
656  extractStructSpecValue(values, StructDLEntryPos::Preferred)) {
657  return emitError(loc) << "preferred alignment is expected to be at least "
658  "as large as ABI alignment";
659  }
660  }
661  return mlir::success();
662 }
663 
664 //===----------------------------------------------------------------------===//
665 // LLVMTargetExtType.
666 //===----------------------------------------------------------------------===//
667 
668 static constexpr llvm::StringRef kSpirvPrefix = "spirv.";
669 static constexpr llvm::StringRef kArmSVCount = "aarch64.svcount";
670 
671 bool LLVM::LLVMTargetExtType::hasProperty(Property prop) const {
672  // See llvm/lib/IR/Type.cpp for reference.
673  uint64_t properties = 0;
674 
675  if (getExtTypeName().starts_with(kSpirvPrefix))
676  properties |=
677  (LLVMTargetExtType::HasZeroInit | LLVM::LLVMTargetExtType::CanBeGlobal);
678 
679  return (properties & prop) == prop;
680 }
681 
682 bool LLVM::LLVMTargetExtType::supportsMemOps() const {
683  // See llvm/lib/IR/Type.cpp for reference.
684  if (getExtTypeName().starts_with(kSpirvPrefix))
685  return true;
686 
687  if (getExtTypeName() == kArmSVCount)
688  return true;
689 
690  return false;
691 }
692 
693 //===----------------------------------------------------------------------===//
694 // LLVMPPCFP128Type
695 //===----------------------------------------------------------------------===//
696 
697 const llvm::fltSemantics &LLVMPPCFP128Type::getFloatSemantics() const {
698  return APFloat::PPCDoubleDouble();
699 }
700 
701 //===----------------------------------------------------------------------===//
702 // Utility functions.
703 //===----------------------------------------------------------------------===//
704 
706  // clang-format off
707  if (llvm::isa<
708  BFloat16Type,
709  Float16Type,
710  Float32Type,
711  Float64Type,
712  Float80Type,
713  Float128Type,
714  LLVMArrayType,
715  LLVMFunctionType,
716  LLVMLabelType,
717  LLVMMetadataType,
718  LLVMPPCFP128Type,
719  LLVMPointerType,
720  LLVMStructType,
721  LLVMTokenType,
722  LLVMTargetExtType,
723  LLVMVoidType,
724  LLVMX86AMXType
725  >(type)) {
726  // clang-format on
727  return true;
728  }
729 
730  // Only signless integers are compatible.
731  if (auto intType = llvm::dyn_cast<IntegerType>(type))
732  return intType.isSignless();
733 
734  // 1D vector types are compatible.
735  if (auto vecType = llvm::dyn_cast<VectorType>(type))
736  return vecType.getRank() == 1;
737 
738  return false;
739 }
740 
741 static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) {
742  if (!compatibleTypes.insert(type).second)
743  return true;
744 
745  auto isCompatible = [&](Type type) {
746  return isCompatibleImpl(type, compatibleTypes);
747  };
748 
749  bool result =
751  .Case<LLVMStructType>([&](auto structType) {
752  return llvm::all_of(structType.getBody(), isCompatible);
753  })
754  .Case<LLVMFunctionType>([&](auto funcType) {
755  return isCompatible(funcType.getReturnType()) &&
756  llvm::all_of(funcType.getParams(), isCompatible);
757  })
758  .Case<IntegerType>([](auto intType) { return intType.isSignless(); })
759  .Case<VectorType>([&](auto vecType) {
760  return vecType.getRank() == 1 &&
761  isCompatible(vecType.getElementType());
762  })
763  .Case<LLVMPointerType>([&](auto pointerType) { return true; })
764  .Case<LLVMTargetExtType>([&](auto extType) {
765  return llvm::all_of(extType.getTypeParams(), isCompatible);
766  })
767  // clang-format off
768  .Case<
769  LLVMArrayType
770  >([&](auto containerType) {
771  return isCompatible(containerType.getElementType());
772  })
773  .Case<
774  BFloat16Type,
775  Float16Type,
776  Float32Type,
777  Float64Type,
778  Float80Type,
779  Float128Type,
780  LLVMLabelType,
781  LLVMMetadataType,
782  LLVMPPCFP128Type,
783  LLVMTokenType,
784  LLVMVoidType,
785  LLVMX86AMXType
786  >([](Type) { return true; })
787  // clang-format on
788  .Default([](Type) { return false; });
789 
790  if (!result)
791  compatibleTypes.erase(type);
792 
793  return result;
794 }
795 
797  if (auto *llvmDialect =
798  type.getContext()->getLoadedDialect<LLVM::LLVMDialect>())
799  return isCompatibleImpl(type, llvmDialect->compatibleTypes.get());
800 
801  DenseSet<Type> localCompatibleTypes;
802  return isCompatibleImpl(type, localCompatibleTypes);
803 }
804 
806  return LLVMDialect::isCompatibleType(type);
807 }
808 
810  return llvm::isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
811  Float80Type, Float128Type, LLVMPPCFP128Type>(type);
812 }
813 
815  if (auto vecType = llvm::dyn_cast<VectorType>(type)) {
816  if (vecType.getRank() != 1)
817  return false;
818  Type elementType = vecType.getElementType();
819  if (auto intType = llvm::dyn_cast<IntegerType>(elementType))
820  return intType.isSignless();
821  return llvm::isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
822  Float80Type, Float128Type, LLVMPointerType>(elementType);
823  }
824  return false;
825 }
826 
827 llvm::ElementCount mlir::LLVM::getVectorNumElements(Type type) {
828  auto vecTy = dyn_cast<VectorType>(type);
829  assert(vecTy && "incompatible with LLVM vector type");
830  if (vecTy.isScalable())
831  return llvm::ElementCount::getScalable(vecTy.getNumElements());
832  return llvm::ElementCount::getFixed(vecTy.getNumElements());
833 }
834 
836  assert(llvm::isa<VectorType>(vectorType) &&
837  "expected LLVM-compatible vector type");
838  return llvm::cast<VectorType>(vectorType).isScalable();
839 }
840 
841 Type mlir::LLVM::getVectorType(Type elementType, unsigned numElements,
842  bool isScalable) {
843  assert(VectorType::isValidElementType(elementType) &&
844  "incompatible element type");
845  return VectorType::get(numElements, elementType, {isScalable});
846 }
847 
849  const llvm::ElementCount &numElements) {
850  if (numElements.isScalable())
851  return getVectorType(elementType, numElements.getKnownMinValue(),
852  /*isScalable=*/true);
853  return getVectorType(elementType, numElements.getFixedValue(),
854  /*isScalable=*/false);
855 }
856 
858  assert(isCompatibleType(type) &&
859  "expected a type compatible with the LLVM dialect");
860 
862  .Case<BFloat16Type, Float16Type>(
863  [](Type) { return llvm::TypeSize::getFixed(16); })
864  .Case<Float32Type>([](Type) { return llvm::TypeSize::getFixed(32); })
865  .Case<Float64Type>([](Type) { return llvm::TypeSize::getFixed(64); })
866  .Case<Float80Type>([](Type) { return llvm::TypeSize::getFixed(80); })
867  .Case<Float128Type>([](Type) { return llvm::TypeSize::getFixed(128); })
868  .Case<IntegerType>([](IntegerType intTy) {
869  return llvm::TypeSize::getFixed(intTy.getWidth());
870  })
871  .Case<LLVMPPCFP128Type>(
872  [](Type) { return llvm::TypeSize::getFixed(128); })
873  .Case<VectorType>([](VectorType t) {
874  assert(isCompatibleVectorType(t) &&
875  "unexpected incompatible with LLVM vector type");
876  llvm::TypeSize elementSize =
877  getPrimitiveTypeSizeInBits(t.getElementType());
878  return llvm::TypeSize(elementSize.getFixedValue() * t.getNumElements(),
879  elementSize.isScalable());
880  })
881  .Default([](Type ty) {
882  assert((llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
883  LLVMTokenType, LLVMStructType, LLVMArrayType,
884  LLVMPointerType, LLVMFunctionType, LLVMTargetExtType>(
885  ty)) &&
886  "unexpected missing support for primitive type");
887  return llvm::TypeSize::getFixed(0);
888  });
889 }
890 
891 //===----------------------------------------------------------------------===//
892 // LLVMDialect
893 //===----------------------------------------------------------------------===//
894 
895 void LLVMDialect::registerTypes() {
896  addTypes<
897 #define GET_TYPEDEF_LIST
898 #include "mlir/Dialect/LLVMIR/LLVMTypes.cpp.inc"
899  >();
900 }
901 
903  return detail::parseType(parser);
904 }
905 
906 void LLVMDialect::printType(Type type, DialectAsmPrinter &os) const {
907  return detail::printType(type, os);
908 }
static LogicalResult verifyEntries(function_ref< InFlightDiagnostic()> emitError, ArrayRef< DataLayoutEntryInterface > entries, bool allowTypes=true)
Verify entries, with the option to disallow types as keys.
Definition: DLTI.cpp:136
static uint64_t getIndexBitwidth(DataLayoutEntryListRef params)
Returns the bitwidth of the index type if specified in the param list.
static MLIRContext * getContext(OpFoldResult val)
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
constexpr static const uint64_t kDefaultPointerAlignment
Definition: LLVMTypes.cpp:264
static void printExtTypeParams(AsmPrinter &p, ArrayRef< Type > typeParams, ArrayRef< unsigned int > intParams)
Definition: LLVMTypes.cpp:122
constexpr static const uint64_t kDefaultPointerSizeBits
Definition: LLVMTypes.cpp:263
static LLVM_ATTRIBUTE_UNUSED OptionalParseResult generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value)
These are unused for now.
static void printFunctionTypes(AsmPrinter &p, ArrayRef< Type > params, bool isVarArg)
Definition: LLVMTypes.cpp:66
static bool isCompatibleImpl(Type type, DenseSet< Type > &compatibleTypes)
Definition: LLVMTypes.cpp:741
static ParseResult parseExtTypeParams(AsmParser &p, SmallVectorImpl< Type > &typeParams, SmallVectorImpl< unsigned int > &intParams)
Parses the parameter list for a target extension type.
Definition: LLVMTypes.cpp:90
static LLVM_ATTRIBUTE_UNUSED LogicalResult generatedTypePrinter(Type def, AsmPrinter &printer)
static constexpr llvm::StringRef kSpirvPrefix
Definition: LLVMTypes.cpp:668
static ParseResult parseFunctionTypes(AsmParser &p, SmallVector< Type > &params, bool &isVarArg)
Definition: LLVMTypes.cpp:36
constexpr static const uint64_t kBitsInByte
Definition: LLVMTypes.cpp:30
static constexpr llvm::StringRef kArmSVCount
Definition: LLVMTypes.cpp:669
static std::optional< uint64_t > getPointerDataLayoutEntry(DataLayoutEntryListRef params, LLVMPointerType type, PtrDLEntryPos pos)
Returns the part of the data layout entry that corresponds to pos for the given type by interpreting ...
Definition: LLVMTypes.cpp:280
static std::optional< uint64_t > getStructDataLayoutEntry(DataLayoutEntryListRef params, LLVMStructType type, StructDLEntryPos pos)
Definition: LLVMTypes.cpp:545
static uint64_t extractStructSpecValue(Attribute attr, StructDLEntryPos pos)
Definition: LLVMTypes.cpp:601
static uint64_t calculateStructAlignment(const DataLayout &dataLayout, DataLayoutEntryListRef params, LLVMStructType type, StructDLEntryPos pos)
Definition: LLVMTypes.cpp:563
static SmallVector< Type > getReturnTypes(SmallVector< func::ReturnOp > returnOps)
Helper function that returns the return types (skipping casts) of the given func.return ops.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:188
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual SMLoc getCurrentLocation()=0
Return the location of the next token.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseOptionalEllipsis()=0
Parse a ... token if present.
This base class exposes generic asm printer hooks, usable across the various derived printers.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
The main mechanism for performing data layout queries.
std::optional< uint64_t > getTypeIndexBitwidth(Type t) const
Returns the bitwidth that should be used when performing index computations for the given pointer-lik...
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
uint64_t getTypePreferredAlignment(Type t) const
Returns the preferred alignment of the given type in the current scope.
uint64_t getTypeABIAlignment(Type t) const
Returns the required alignment of the given type in the current scope.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
Type parseType(DialectAsmParser &parser)
Parses an LLVM dialect type.
Type getVectorType(Type elementType, unsigned numElements, bool isScalable=false)
Creates an LLVM dialect-compatible vector type with the given element type and length.
Definition: LLVMTypes.cpp:841
llvm::TypeSize getPrimitiveTypeSizeInBits(Type type)
Returns the size of the given primitive LLVM dialect-compatible type (including vectors) in bits,...
Definition: LLVMTypes.cpp:857
void printPrettyLLVMType(AsmPrinter &p, Type type)
Print any MLIR type or a concise syntax for LLVM types.
bool isScalableVectorType(Type vectorType)
Returns whether a vector type is scalable or not.
Definition: LLVMTypes.cpp:835
ParseResult parsePrettyLLVMType(AsmParser &p, Type &type)
Parse any MLIR type or a concise syntax for LLVM types.
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:814
bool isCompatibleOuterType(Type type)
Returns true if the given outer type is compatible with the LLVM dialect without checking its potenti...
Definition: LLVMTypes.cpp:705
PtrDLEntryPos
The positions of different values in the data layout entry for pointers.
Definition: LLVMTypes.h:136
std::optional< uint64_t > extractPointerSpecValue(Attribute attr, PtrDLEntryPos pos)
Returns the value that corresponds to named position pos from the data layout entry attr assuming it'...
Definition: LLVMTypes.cpp:266
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:796
bool isCompatibleFloatingPointType(Type type)
Returns true if the given type is a floating-point type compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:809
llvm::ElementCount getVectorNumElements(Type type)
Returns the element count of any LLVM-compatible vector type.
Definition: LLVMTypes.cpp:827
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type to an MLIR context if it was valid.
::llvm::MapVector<::mlir::StringAttr, ::mlir::DataLayoutEntryInterface > DataLayoutIdentifiedEntryMap
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423