MLIR  21.0.0git
LLVMTypes.cpp
Go to the documentation of this file.
1 //===- LLVMTypes.cpp - MLIR LLVM dialect types ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the types for the LLVM dialect in MLIR. These MLIR types
10 // correspond to the LLVM IR type system.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "TypeDetail.h"
15 
18 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/TypeSupport.h"
21 
22 #include "llvm/ADT/ScopeExit.h"
23 #include "llvm/ADT/TypeSwitch.h"
24 #include "llvm/Support/TypeSize.h"
25 #include <optional>
26 
27 using namespace mlir;
28 using namespace mlir::LLVM;
29 
/// Factor used to convert between byte and bit quantities in the data layout
/// computations of this file.
static constexpr uint64_t kBitsInByte = 8;
31 
32 //===----------------------------------------------------------------------===//
33 // custom<FunctionTypes>
34 //===----------------------------------------------------------------------===//
35 
36 static ParseResult parseFunctionTypes(AsmParser &p, SmallVector<Type> &params,
37  bool &isVarArg) {
38  isVarArg = false;
39  // `(` `)`
40  if (succeeded(p.parseOptionalRParen()))
41  return success();
42 
43  // `(` `...` `)`
44  if (succeeded(p.parseOptionalEllipsis())) {
45  isVarArg = true;
46  return p.parseRParen();
47  }
48 
49  // type (`,` type)* (`,` `...`)?
50  Type type;
51  if (parsePrettyLLVMType(p, type))
52  return failure();
53  params.push_back(type);
54  while (succeeded(p.parseOptionalComma())) {
55  if (succeeded(p.parseOptionalEllipsis())) {
56  isVarArg = true;
57  return p.parseRParen();
58  }
59  if (parsePrettyLLVMType(p, type))
60  return failure();
61  params.push_back(type);
62  }
63  return p.parseRParen();
64 }
65 
67  bool isVarArg) {
68  llvm::interleaveComma(params, p,
69  [&](Type type) { printPrettyLLVMType(p, type); });
70  if (isVarArg) {
71  if (!params.empty())
72  p << ", ";
73  p << "...";
74  }
75  p << ')';
76 }
77 
78 //===----------------------------------------------------------------------===//
79 // custom<ExtTypeParams>
80 //===----------------------------------------------------------------------===//
81 
82 /// Parses the parameter list for a target extension type. The parameter list
83 /// contains an optional list of type parameters, followed by an optional list
84 /// of integer parameters. Type and integer parameters cannot be interleaved in
85 /// the list.
86 /// extTypeParams ::= typeList? | intList? | (typeList "," intList)
87 /// typeList ::= type ("," type)*
88 /// intList ::= integer ("," integer)*
89 static ParseResult
91  SmallVectorImpl<unsigned int> &intParams) {
92  bool parseType = true;
93  auto typeOrIntParser = [&]() -> ParseResult {
94  unsigned int i;
95  auto intResult = p.parseOptionalInteger(i);
96  if (intResult.has_value() && !failed(*intResult)) {
97  // Successfully parsed an integer.
98  intParams.push_back(i);
99  // After the first integer was successfully parsed, no
100  // more types can be parsed.
101  parseType = false;
102  return success();
103  }
104  if (parseType) {
105  Type t;
106  if (!parsePrettyLLVMType(p, t)) {
107  // Successfully parsed a type.
108  typeParams.push_back(t);
109  return success();
110  }
111  }
112  return failure();
113  };
114  if (p.parseCommaSeparatedList(typeOrIntParser)) {
116  "failed to parse parameter list for target extension type");
117  return failure();
118  }
119  return success();
120 }
121 
122 static void printExtTypeParams(AsmPrinter &p, ArrayRef<Type> typeParams,
123  ArrayRef<unsigned int> intParams) {
124  p << typeParams;
125  if (!typeParams.empty() && !intParams.empty())
126  p << ", ";
127 
128  p << intParams;
129 }
130 
131 //===----------------------------------------------------------------------===//
132 // ODS-Generated Definitions
133 //===----------------------------------------------------------------------===//
134 
135 /// These are unused for now.
136 /// TODO: Move over to these once more types have been migrated to TypeDef.
137 LLVM_ATTRIBUTE_UNUSED static OptionalParseResult
138 generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value);
139 LLVM_ATTRIBUTE_UNUSED static LogicalResult
141 
142 #include "mlir/Dialect/LLVMIR/LLVMTypeInterfaces.cpp.inc"
143 
144 #define GET_TYPEDEF_CLASSES
145 #include "mlir/Dialect/LLVMIR/LLVMTypes.cpp.inc"
146 
147 //===----------------------------------------------------------------------===//
148 // LLVMArrayType
149 //===----------------------------------------------------------------------===//
150 
151 bool LLVMArrayType::isValidElementType(Type type) {
152  return !llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
153  LLVMFunctionType, LLVMTokenType>(type);
154 }
155 
156 LLVMArrayType LLVMArrayType::get(Type elementType, uint64_t numElements) {
157  assert(elementType && "expected non-null subtype");
158  return Base::get(elementType.getContext(), elementType, numElements);
159 }
160 
161 LLVMArrayType
162 LLVMArrayType::getChecked(function_ref<InFlightDiagnostic()> emitError,
163  Type elementType, uint64_t numElements) {
164  assert(elementType && "expected non-null subtype");
165  return Base::getChecked(emitError, elementType.getContext(), elementType,
166  numElements);
167 }
168 
169 LogicalResult
171  Type elementType, uint64_t numElements) {
172  if (!isValidElementType(elementType))
173  return emitError() << "invalid array element type: " << elementType;
174  return success();
175 }
176 
177 //===----------------------------------------------------------------------===//
178 // DataLayoutTypeInterface
179 //===----------------------------------------------------------------------===//
180 
181 llvm::TypeSize
182 LLVMArrayType::getTypeSizeInBits(const DataLayout &dataLayout,
183  DataLayoutEntryListRef params) const {
184  return llvm::TypeSize::getFixed(kBitsInByte *
185  getTypeSize(dataLayout, params));
186 }
187 
188 llvm::TypeSize LLVMArrayType::getTypeSize(const DataLayout &dataLayout,
189  DataLayoutEntryListRef params) const {
190  return llvm::alignTo(dataLayout.getTypeSize(getElementType()),
191  dataLayout.getTypeABIAlignment(getElementType())) *
192  getNumElements();
193 }
194 
195 uint64_t LLVMArrayType::getABIAlignment(const DataLayout &dataLayout,
196  DataLayoutEntryListRef params) const {
197  return dataLayout.getTypeABIAlignment(getElementType());
198 }
199 
200 uint64_t
201 LLVMArrayType::getPreferredAlignment(const DataLayout &dataLayout,
202  DataLayoutEntryListRef params) const {
203  return dataLayout.getTypePreferredAlignment(getElementType());
204 }
205 
206 //===----------------------------------------------------------------------===//
207 // Function type.
208 //===----------------------------------------------------------------------===//
209 
210 bool LLVMFunctionType::isValidArgumentType(Type type) {
211  return !llvm::isa<LLVMVoidType, LLVMFunctionType>(type);
212 }
213 
214 bool LLVMFunctionType::isValidResultType(Type type) {
215  return !llvm::isa<LLVMFunctionType, LLVMMetadataType, LLVMLabelType>(type);
216 }
217 
218 LLVMFunctionType LLVMFunctionType::get(Type result, ArrayRef<Type> arguments,
219  bool isVarArg) {
220  assert(result && "expected non-null result");
221  return Base::get(result.getContext(), result, arguments, isVarArg);
222 }
223 
224 LLVMFunctionType
225 LLVMFunctionType::getChecked(function_ref<InFlightDiagnostic()> emitError,
226  Type result, ArrayRef<Type> arguments,
227  bool isVarArg) {
228  assert(result && "expected non-null result");
229  return Base::getChecked(emitError, result.getContext(), result, arguments,
230  isVarArg);
231 }
232 
233 LLVMFunctionType LLVMFunctionType::clone(TypeRange inputs,
234  TypeRange results) const {
235  assert(results.size() == 1 && "expected a single result type");
236  return get(results[0], llvm::to_vector(inputs), isVarArg());
237 }
238 
240  return static_cast<detail::LLVMFunctionTypeStorage *>(getImpl())->returnType;
241 }
242 
243 LogicalResult
245  Type result, ArrayRef<Type> arguments, bool) {
246  if (!isValidResultType(result))
247  return emitError() << "invalid function result type: " << result;
248 
249  for (Type arg : arguments)
250  if (!isValidArgumentType(arg))
251  return emitError() << "invalid function argument type: " << arg;
252 
253  return success();
254 }
255 
256 //===----------------------------------------------------------------------===//
257 // DataLayoutTypeInterface
258 //===----------------------------------------------------------------------===//
259 
260 constexpr const static uint64_t kDefaultPointerSizeBits = 64;
261 constexpr const static uint64_t kDefaultPointerAlignment = 8;
262 
263 std::optional<uint64_t> mlir::LLVM::extractPointerSpecValue(Attribute attr,
264  PtrDLEntryPos pos) {
265  auto spec = cast<DenseIntElementsAttr>(attr);
266  auto idx = static_cast<int64_t>(pos);
267  if (idx >= spec.size())
268  return std::nullopt;
269  return spec.getValues<uint64_t>()[idx];
270 }
271 
272 /// Returns the part of the data layout entry that corresponds to `pos` for the
273 /// given `type` by interpreting the list of entries `params`. For the pointer
274 /// type in the default address space, returns the default value if the entries
275 /// do not provide a custom one, for other address spaces returns std::nullopt.
276 static std::optional<uint64_t>
278  PtrDLEntryPos pos) {
279  // First, look for the entry for the pointer in the current address space.
280  Attribute currentEntry;
281  for (DataLayoutEntryInterface entry : params) {
282  if (!entry.isTypeEntry())
283  continue;
284  if (cast<LLVMPointerType>(cast<Type>(entry.getKey())).getAddressSpace() ==
285  type.getAddressSpace()) {
286  currentEntry = entry.getValue();
287  break;
288  }
289  }
290  if (currentEntry) {
291  std::optional<uint64_t> value = extractPointerSpecValue(currentEntry, pos);
292  // If the optional `PtrDLEntryPos::Index` entry is not available, use the
293  // pointer size as the index bitwidth.
294  if (!value && pos == PtrDLEntryPos::Index)
295  value = extractPointerSpecValue(currentEntry, PtrDLEntryPos::Size);
296  bool isSizeOrIndex =
297  pos == PtrDLEntryPos::Size || pos == PtrDLEntryPos::Index;
298  return *value / (isSizeOrIndex ? 1 : kBitsInByte);
299  }
300 
301  // If not found, and this is the pointer to the default memory space, assume
302  // 64-bit pointers.
303  if (type.getAddressSpace() == 0) {
304  bool isSizeOrIndex =
305  pos == PtrDLEntryPos::Size || pos == PtrDLEntryPos::Index;
306  return isSizeOrIndex ? kDefaultPointerSizeBits : kDefaultPointerAlignment;
307  }
308 
309  return std::nullopt;
310 }
311 
312 llvm::TypeSize
313 LLVMPointerType::getTypeSizeInBits(const DataLayout &dataLayout,
314  DataLayoutEntryListRef params) const {
315  if (std::optional<uint64_t> size =
317  return llvm::TypeSize::getFixed(*size);
318 
319  // For other memory spaces, use the size of the pointer to the default memory
320  // space.
321  return dataLayout.getTypeSizeInBits(get(getContext()));
322 }
323 
324 uint64_t LLVMPointerType::getABIAlignment(const DataLayout &dataLayout,
325  DataLayoutEntryListRef params) const {
326  if (std::optional<uint64_t> alignment =
328  return *alignment;
329 
330  return dataLayout.getTypeABIAlignment(get(getContext()));
331 }
332 
333 uint64_t
334 LLVMPointerType::getPreferredAlignment(const DataLayout &dataLayout,
335  DataLayoutEntryListRef params) const {
336  if (std::optional<uint64_t> alignment =
338  return *alignment;
339 
340  return dataLayout.getTypePreferredAlignment(get(getContext()));
341 }
342 
343 std::optional<uint64_t>
345  DataLayoutEntryListRef params) const {
346  if (std::optional<uint64_t> indexBitwidth =
348  return *indexBitwidth;
349 
350  return dataLayout.getTypeIndexBitwidth(get(getContext()));
351 }
352 
353 bool LLVMPointerType::areCompatible(
354  DataLayoutEntryListRef oldLayout, DataLayoutEntryListRef newLayout,
355  DataLayoutSpecInterface newSpec,
356  const DataLayoutIdentifiedEntryMap &map) const {
357  for (DataLayoutEntryInterface newEntry : newLayout) {
358  if (!newEntry.isTypeEntry())
359  continue;
360  uint64_t size = kDefaultPointerSizeBits;
361  uint64_t abi = kDefaultPointerAlignment;
362  auto newType =
363  llvm::cast<LLVMPointerType>(llvm::cast<Type>(newEntry.getKey()));
364  const auto *it =
365  llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
366  if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
367  return llvm::cast<LLVMPointerType>(type).getAddressSpace() ==
368  newType.getAddressSpace();
369  }
370  return false;
371  });
372  if (it == oldLayout.end()) {
373  llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
374  if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
375  return llvm::cast<LLVMPointerType>(type).getAddressSpace() == 0;
376  }
377  return false;
378  });
379  }
380  if (it != oldLayout.end()) {
383  }
384 
385  Attribute newSpec = llvm::cast<DenseIntElementsAttr>(newEntry.getValue());
386  uint64_t newSize = *extractPointerSpecValue(newSpec, PtrDLEntryPos::Size);
387  uint64_t newAbi = *extractPointerSpecValue(newSpec, PtrDLEntryPos::Abi);
388  if (size != newSize || abi < newAbi || abi % newAbi != 0)
389  return false;
390  }
391  return true;
392 }
393 
395  Location loc) const {
396  for (DataLayoutEntryInterface entry : entries) {
397  if (!entry.isTypeEntry())
398  continue;
399  auto key = llvm::cast<Type>(entry.getKey());
400  auto values = llvm::dyn_cast<DenseIntElementsAttr>(entry.getValue());
401  if (!values || (values.size() != 3 && values.size() != 4)) {
402  return emitError(loc)
403  << "expected layout attribute for " << key
404  << " to be a dense integer elements attribute with 3 or 4 "
405  "elements";
406  }
407  if (!values.getElementType().isInteger(64))
408  return emitError(loc) << "expected i64 parameters for " << key;
409 
412  return emitError(loc) << "preferred alignment is expected to be at least "
413  "as large as ABI alignment";
414  }
415  }
416  return success();
417 }
418 
419 //===----------------------------------------------------------------------===//
420 // Struct type.
421 //===----------------------------------------------------------------------===//
422 
423 bool LLVMStructType::isValidElementType(Type type) {
424  return !llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
425  LLVMFunctionType, LLVMTokenType>(type);
426 }
427 
428 LLVMStructType LLVMStructType::getIdentified(MLIRContext *context,
429  StringRef name) {
430  return Base::get(context, name, /*opaque=*/false);
431 }
432 
433 LLVMStructType LLVMStructType::getIdentifiedChecked(
435  StringRef name) {
436  return Base::getChecked(emitError, context, name, /*opaque=*/false);
437 }
438 
439 LLVMStructType LLVMStructType::getNewIdentified(MLIRContext *context,
440  StringRef name,
441  ArrayRef<Type> elements,
442  bool isPacked) {
443  std::string stringName = name.str();
444  unsigned counter = 0;
445  do {
446  auto type = LLVMStructType::getIdentified(context, stringName);
447  if (type.isInitialized() || failed(type.setBody(elements, isPacked))) {
448  counter += 1;
449  stringName = (Twine(name) + "." + std::to_string(counter)).str();
450  continue;
451  }
452  return type;
453  } while (true);
454 }
455 
456 LLVMStructType LLVMStructType::getLiteral(MLIRContext *context,
457  ArrayRef<Type> types, bool isPacked) {
458  return Base::get(context, types, isPacked);
459 }
460 
461 LLVMStructType
462 LLVMStructType::getLiteralChecked(function_ref<InFlightDiagnostic()> emitError,
463  MLIRContext *context, ArrayRef<Type> types,
464  bool isPacked) {
465  return Base::getChecked(emitError, context, types, isPacked);
466 }
467 
468 LLVMStructType LLVMStructType::getOpaque(StringRef name, MLIRContext *context) {
469  return Base::get(context, name, /*opaque=*/true);
470 }
471 
472 LLVMStructType
473 LLVMStructType::getOpaqueChecked(function_ref<InFlightDiagnostic()> emitError,
474  MLIRContext *context, StringRef name) {
475  return Base::getChecked(emitError, context, name, /*opaque=*/true);
476 }
477 
478 LogicalResult LLVMStructType::setBody(ArrayRef<Type> types, bool isPacked) {
479  assert(isIdentified() && "can only set bodies of identified structs");
480  assert(llvm::all_of(types, LLVMStructType::isValidElementType) &&
481  "expected valid body types");
482  return Base::mutate(types, isPacked);
483 }
484 
485 bool LLVMStructType::isPacked() const { return getImpl()->isPacked(); }
486 bool LLVMStructType::isIdentified() const { return getImpl()->isIdentified(); }
487 bool LLVMStructType::isOpaque() const {
488  return getImpl()->isIdentified() &&
489  (getImpl()->isOpaque() || !getImpl()->isInitialized());
490 }
491 bool LLVMStructType::isInitialized() { return getImpl()->isInitialized(); }
492 StringRef LLVMStructType::getName() const { return getImpl()->getIdentifier(); }
493 ArrayRef<Type> LLVMStructType::getBody() const {
494  return isIdentified() ? getImpl()->getIdentifiedStructBody()
495  : getImpl()->getTypeList();
496 }
497 
498 LogicalResult
499 LLVMStructType::verifyInvariants(function_ref<InFlightDiagnostic()>, StringRef,
500  bool) {
501  return success();
502 }
503 
504 LogicalResult
505 LLVMStructType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
506  ArrayRef<Type> types, bool) {
507  for (Type t : types)
508  if (!isValidElementType(t))
509  return emitError() << "invalid LLVM structure element type: " << t;
510 
511  return success();
512 }
513 
514 llvm::TypeSize
515 LLVMStructType::getTypeSizeInBits(const DataLayout &dataLayout,
516  DataLayoutEntryListRef params) const {
517  auto structSize = llvm::TypeSize::getFixed(0);
518  uint64_t structAlignment = 1;
519  for (Type element : getBody()) {
520  uint64_t elementAlignment =
521  isPacked() ? 1 : dataLayout.getTypeABIAlignment(element);
522  // Add padding to the struct size to align it to the abi alignment of the
523  // element type before than adding the size of the element.
524  structSize = llvm::alignTo(structSize, elementAlignment);
525  structSize += dataLayout.getTypeSize(element);
526 
527  // The alignment requirement of a struct is equal to the strictest alignment
528  // requirement of its elements.
529  structAlignment = std::max(elementAlignment, structAlignment);
530  }
531  // At the end, add padding to the struct to satisfy its own alignment
532  // requirement. Otherwise structs inside of arrays would be misaligned.
533  structSize = llvm::alignTo(structSize, structAlignment);
534  return structSize * kBitsInByte;
535 }
536 
namespace {
/// Index of each component inside a struct data layout entry value. Values are
/// alignments expressed in bits; the preferred component may be omitted.
enum class StructDLEntryPos {
  Abi = 0,
  Preferred = 1,
};
} // namespace
540 
541 static std::optional<uint64_t>
543  StructDLEntryPos pos) {
544  const auto *currentEntry =
545  llvm::find_if(params, [](DataLayoutEntryInterface entry) {
546  return entry.isTypeEntry();
547  });
548  if (currentEntry == params.end())
549  return std::nullopt;
550 
551  auto attr = llvm::cast<DenseIntElementsAttr>(currentEntry->getValue());
552  if (pos == StructDLEntryPos::Preferred &&
553  attr.size() <= static_cast<int64_t>(StructDLEntryPos::Preferred))
554  // If no preferred was specified, fall back to abi alignment
555  pos = StructDLEntryPos::Abi;
556 
557  return attr.getValues<uint64_t>()[static_cast<size_t>(pos)];
558 }
559 
560 static uint64_t calculateStructAlignment(const DataLayout &dataLayout,
561  DataLayoutEntryListRef params,
562  LLVMStructType type,
563  StructDLEntryPos pos) {
564  // Packed structs always have an abi alignment of 1
565  if (pos == StructDLEntryPos::Abi && type.isPacked()) {
566  return 1;
567  }
568 
569  // The alignment requirement of a struct is equal to the strictest alignment
570  // requirement of its elements.
571  uint64_t structAlignment = 1;
572  for (Type iter : type.getBody()) {
573  structAlignment =
574  std::max(dataLayout.getTypeABIAlignment(iter), structAlignment);
575  }
576 
577  // Entries are only allowed to be stricter than the required alignment
578  if (std::optional<uint64_t> entryResult =
579  getStructDataLayoutEntry(params, type, pos))
580  return std::max(*entryResult / kBitsInByte, structAlignment);
581 
582  return structAlignment;
583 }
584 
585 uint64_t LLVMStructType::getABIAlignment(const DataLayout &dataLayout,
586  DataLayoutEntryListRef params) const {
587  return calculateStructAlignment(dataLayout, params, *this,
588  StructDLEntryPos::Abi);
589 }
590 
591 uint64_t
592 LLVMStructType::getPreferredAlignment(const DataLayout &dataLayout,
593  DataLayoutEntryListRef params) const {
594  return calculateStructAlignment(dataLayout, params, *this,
595  StructDLEntryPos::Preferred);
596 }
597 
598 static uint64_t extractStructSpecValue(Attribute attr, StructDLEntryPos pos) {
599  return llvm::cast<DenseIntElementsAttr>(attr)
600  .getValues<uint64_t>()[static_cast<size_t>(pos)];
601 }
602 
603 bool LLVMStructType::areCompatible(
604  DataLayoutEntryListRef oldLayout, DataLayoutEntryListRef newLayout,
605  DataLayoutSpecInterface newSpec,
606  const DataLayoutIdentifiedEntryMap &map) const {
607  for (DataLayoutEntryInterface newEntry : newLayout) {
608  if (!newEntry.isTypeEntry())
609  continue;
610 
611  const auto *previousEntry =
612  llvm::find_if(oldLayout, [](DataLayoutEntryInterface entry) {
613  return entry.isTypeEntry();
614  });
615  if (previousEntry == oldLayout.end())
616  continue;
617 
618  uint64_t abi = extractStructSpecValue(previousEntry->getValue(),
619  StructDLEntryPos::Abi);
620  uint64_t newAbi =
621  extractStructSpecValue(newEntry.getValue(), StructDLEntryPos::Abi);
622  if (abi < newAbi || abi % newAbi != 0)
623  return false;
624  }
625  return true;
626 }
627 
629  Location loc) const {
630  for (DataLayoutEntryInterface entry : entries) {
631  if (!entry.isTypeEntry())
632  continue;
633 
634  auto key = llvm::cast<LLVMStructType>(llvm::cast<Type>(entry.getKey()));
635  auto values = llvm::dyn_cast<DenseIntElementsAttr>(entry.getValue());
636  if (!values || (values.size() != 2 && values.size() != 1)) {
637  return emitError(loc)
638  << "expected layout attribute for "
639  << llvm::cast<Type>(entry.getKey())
640  << " to be a dense integer elements attribute of 1 or 2 elements";
641  }
642  if (!values.getElementType().isInteger(64))
643  return emitError(loc) << "expected i64 entries for " << key;
644 
645  if (key.isIdentified() || !key.getBody().empty()) {
646  return emitError(loc) << "unexpected layout attribute for struct " << key;
647  }
648 
649  if (values.size() == 1)
650  continue;
651 
652  if (extractStructSpecValue(values, StructDLEntryPos::Abi) >
653  extractStructSpecValue(values, StructDLEntryPos::Preferred)) {
654  return emitError(loc) << "preferred alignment is expected to be at least "
655  "as large as ABI alignment";
656  }
657  }
658  return mlir::success();
659 }
660 
661 //===----------------------------------------------------------------------===//
662 // LLVMTargetExtType.
663 //===----------------------------------------------------------------------===//
664 
665 static constexpr llvm::StringRef kSpirvPrefix = "spirv.";
666 static constexpr llvm::StringRef kArmSVCount = "aarch64.svcount";
667 
668 bool LLVM::LLVMTargetExtType::hasProperty(Property prop) const {
669  // See llvm/lib/IR/Type.cpp for reference.
670  uint64_t properties = 0;
671 
672  if (getExtTypeName().starts_with(kSpirvPrefix))
673  properties |=
674  (LLVMTargetExtType::HasZeroInit | LLVM::LLVMTargetExtType::CanBeGlobal);
675 
676  return (properties & prop) == prop;
677 }
678 
679 bool LLVM::LLVMTargetExtType::supportsMemOps() const {
680  // See llvm/lib/IR/Type.cpp for reference.
681  if (getExtTypeName().starts_with(kSpirvPrefix))
682  return true;
683 
684  if (getExtTypeName() == kArmSVCount)
685  return true;
686 
687  return false;
688 }
689 
690 //===----------------------------------------------------------------------===//
691 // LLVMPPCFP128Type
692 //===----------------------------------------------------------------------===//
693 
694 const llvm::fltSemantics &LLVMPPCFP128Type::getFloatSemantics() const {
695  return APFloat::PPCDoubleDouble();
696 }
697 
698 //===----------------------------------------------------------------------===//
699 // Utility functions.
700 //===----------------------------------------------------------------------===//
701 
703  // clang-format off
704  if (llvm::isa<
705  BFloat16Type,
706  Float16Type,
707  Float32Type,
708  Float64Type,
709  Float80Type,
710  Float128Type,
711  LLVMArrayType,
712  LLVMFunctionType,
713  LLVMLabelType,
714  LLVMMetadataType,
715  LLVMPPCFP128Type,
716  LLVMPointerType,
717  LLVMStructType,
718  LLVMTokenType,
719  LLVMTargetExtType,
720  LLVMVoidType,
721  LLVMX86AMXType
722  >(type)) {
723  // clang-format on
724  return true;
725  }
726 
727  // Only signless integers are compatible.
728  if (auto intType = llvm::dyn_cast<IntegerType>(type))
729  return intType.isSignless();
730 
731  // 1D vector types are compatible.
732  if (auto vecType = llvm::dyn_cast<VectorType>(type))
733  return vecType.getRank() == 1;
734 
735  return false;
736 }
737 
738 static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) {
739  if (!compatibleTypes.insert(type).second)
740  return true;
741 
742  auto isCompatible = [&](Type type) {
743  return isCompatibleImpl(type, compatibleTypes);
744  };
745 
746  bool result =
748  .Case<LLVMStructType>([&](auto structType) {
749  return llvm::all_of(structType.getBody(), isCompatible);
750  })
751  .Case<LLVMFunctionType>([&](auto funcType) {
752  return isCompatible(funcType.getReturnType()) &&
753  llvm::all_of(funcType.getParams(), isCompatible);
754  })
755  .Case<IntegerType>([](auto intType) { return intType.isSignless(); })
756  .Case<VectorType>([&](auto vecType) {
757  return vecType.getRank() == 1 &&
758  isCompatible(vecType.getElementType());
759  })
760  .Case<LLVMPointerType>([&](auto pointerType) { return true; })
761  .Case<LLVMTargetExtType>([&](auto extType) {
762  return llvm::all_of(extType.getTypeParams(), isCompatible);
763  })
764  // clang-format off
765  .Case<
766  LLVMArrayType
767  >([&](auto containerType) {
768  return isCompatible(containerType.getElementType());
769  })
770  .Case<
771  BFloat16Type,
772  Float16Type,
773  Float32Type,
774  Float64Type,
775  Float80Type,
776  Float128Type,
777  LLVMLabelType,
778  LLVMMetadataType,
779  LLVMPPCFP128Type,
780  LLVMTokenType,
781  LLVMVoidType,
782  LLVMX86AMXType
783  >([](Type) { return true; })
784  // clang-format on
785  .Default([](Type) { return false; });
786 
787  if (!result)
788  compatibleTypes.erase(type);
789 
790  return result;
791 }
792 
794  if (auto *llvmDialect =
795  type.getContext()->getLoadedDialect<LLVM::LLVMDialect>())
796  return isCompatibleImpl(type, llvmDialect->compatibleTypes.get());
797 
798  DenseSet<Type> localCompatibleTypes;
799  return isCompatibleImpl(type, localCompatibleTypes);
800 }
801 
803  return LLVMDialect::isCompatibleType(type);
804 }
805 
807  return llvm::isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
808  Float80Type, Float128Type, LLVMPPCFP128Type>(type);
809 }
810 
812  if (auto vecType = llvm::dyn_cast<VectorType>(type)) {
813  if (vecType.getRank() != 1)
814  return false;
815  Type elementType = vecType.getElementType();
816  if (auto intType = llvm::dyn_cast<IntegerType>(elementType))
817  return intType.isSignless();
818  return llvm::isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
819  Float80Type, Float128Type, LLVMPointerType>(elementType);
820  }
821  return false;
822 }
823 
824 llvm::ElementCount mlir::LLVM::getVectorNumElements(Type type) {
825  auto vecTy = dyn_cast<VectorType>(type);
826  assert(vecTy && "incompatible with LLVM vector type");
827  if (vecTy.isScalable())
828  return llvm::ElementCount::getScalable(vecTy.getNumElements());
829  return llvm::ElementCount::getFixed(vecTy.getNumElements());
830 }
831 
833  assert(llvm::isa<VectorType>(vectorType) &&
834  "expected LLVM-compatible vector type");
835  return llvm::cast<VectorType>(vectorType).isScalable();
836 }
837 
838 Type mlir::LLVM::getVectorType(Type elementType, unsigned numElements,
839  bool isScalable) {
840  assert(VectorType::isValidElementType(elementType) &&
841  "incompatible element type");
842  return VectorType::get(numElements, elementType, {isScalable});
843 }
844 
846  const llvm::ElementCount &numElements) {
847  if (numElements.isScalable())
848  return getVectorType(elementType, numElements.getKnownMinValue(),
849  /*isScalable=*/true);
850  return getVectorType(elementType, numElements.getFixedValue(),
851  /*isScalable=*/false);
852 }
853 
855  assert(isCompatibleType(type) &&
856  "expected a type compatible with the LLVM dialect");
857 
859  .Case<BFloat16Type, Float16Type>(
860  [](Type) { return llvm::TypeSize::getFixed(16); })
861  .Case<Float32Type>([](Type) { return llvm::TypeSize::getFixed(32); })
862  .Case<Float64Type>([](Type) { return llvm::TypeSize::getFixed(64); })
863  .Case<Float80Type>([](Type) { return llvm::TypeSize::getFixed(80); })
864  .Case<Float128Type>([](Type) { return llvm::TypeSize::getFixed(128); })
865  .Case<IntegerType>([](IntegerType intTy) {
866  return llvm::TypeSize::getFixed(intTy.getWidth());
867  })
868  .Case<LLVMPPCFP128Type>(
869  [](Type) { return llvm::TypeSize::getFixed(128); })
870  .Case<VectorType>([](VectorType t) {
871  assert(isCompatibleVectorType(t) &&
872  "unexpected incompatible with LLVM vector type");
873  llvm::TypeSize elementSize =
874  getPrimitiveTypeSizeInBits(t.getElementType());
875  return llvm::TypeSize(elementSize.getFixedValue() * t.getNumElements(),
876  elementSize.isScalable());
877  })
878  .Default([](Type ty) {
879  assert((llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
880  LLVMTokenType, LLVMStructType, LLVMArrayType,
881  LLVMPointerType, LLVMFunctionType, LLVMTargetExtType>(
882  ty)) &&
883  "unexpected missing support for primitive type");
884  return llvm::TypeSize::getFixed(0);
885  });
886 }
887 
888 //===----------------------------------------------------------------------===//
889 // LLVMDialect
890 //===----------------------------------------------------------------------===//
891 
/// Registers the LLVM dialect types. The type list is produced by the
/// GET_TYPEDEF_LIST section of the ODS-generated LLVMTypes.cpp.inc, spliced
/// into the template argument list of `addTypes`.
892 void LLVMDialect::registerTypes() {
893  addTypes<
894 #define GET_TYPEDEF_LIST
895 #include "mlir/Dialect/LLVMIR/LLVMTypes.cpp.inc"
896  >();
897 }
898 
900  return detail::parseType(parser);
901 }
902 
903 void LLVMDialect::printType(Type type, DialectAsmPrinter &os) const {
904  return detail::printType(type, os);
905 }
static LogicalResult verifyEntries(function_ref< InFlightDiagnostic()> emitError, ArrayRef< DataLayoutEntryInterface > entries, bool allowTypes=true)
Verify entries, with the option to disallow types as keys.
Definition: DLTI.cpp:136
static uint64_t getIndexBitwidth(DataLayoutEntryListRef params)
Returns the bitwidth of the index type if specified in the param list.
static MLIRContext * getContext(OpFoldResult val)
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
constexpr static const uint64_t kDefaultPointerAlignment
Definition: LLVMTypes.cpp:261
static void printExtTypeParams(AsmPrinter &p, ArrayRef< Type > typeParams, ArrayRef< unsigned int > intParams)
Definition: LLVMTypes.cpp:122
constexpr static const uint64_t kDefaultPointerSizeBits
Definition: LLVMTypes.cpp:260
static LLVM_ATTRIBUTE_UNUSED OptionalParseResult generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value)
These are unused for now.
static void printFunctionTypes(AsmPrinter &p, ArrayRef< Type > params, bool isVarArg)
Definition: LLVMTypes.cpp:66
static bool isCompatibleImpl(Type type, DenseSet< Type > &compatibleTypes)
Definition: LLVMTypes.cpp:738
static ParseResult parseExtTypeParams(AsmParser &p, SmallVectorImpl< Type > &typeParams, SmallVectorImpl< unsigned int > &intParams)
Parses the parameter list for a target extension type.
Definition: LLVMTypes.cpp:90
static LLVM_ATTRIBUTE_UNUSED LogicalResult generatedTypePrinter(Type def, AsmPrinter &printer)
static constexpr llvm::StringRef kSpirvPrefix
Definition: LLVMTypes.cpp:665
static ParseResult parseFunctionTypes(AsmParser &p, SmallVector< Type > &params, bool &isVarArg)
Definition: LLVMTypes.cpp:36
constexpr static const uint64_t kBitsInByte
Definition: LLVMTypes.cpp:30
static constexpr llvm::StringRef kArmSVCount
Definition: LLVMTypes.cpp:666
static std::optional< uint64_t > getPointerDataLayoutEntry(DataLayoutEntryListRef params, LLVMPointerType type, PtrDLEntryPos pos)
Returns the part of the data layout entry that corresponds to pos for the given type by interpreting ...
Definition: LLVMTypes.cpp:277
static std::optional< uint64_t > getStructDataLayoutEntry(DataLayoutEntryListRef params, LLVMStructType type, StructDLEntryPos pos)
Definition: LLVMTypes.cpp:542
static uint64_t extractStructSpecValue(Attribute attr, StructDLEntryPos pos)
Definition: LLVMTypes.cpp:598
static uint64_t calculateStructAlignment(const DataLayout &dataLayout, DataLayoutEntryListRef params, LLVMStructType type, StructDLEntryPos pos)
Definition: LLVMTypes.cpp:560
static SmallVector< Type > getReturnTypes(SmallVector< func::ReturnOp > returnOps)
Helper function that returns the return types (skipping casts) of the given func.return ops.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:187
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseOptionalEllipsis()=0
Parse a `...` token if present.
This base class exposes generic asm printer hooks, usable across the various derived printers.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
The main mechanism for performing data layout queries.
std::optional< uint64_t > getTypeIndexBitwidth(Type t) const
Returns the bitwidth that should be used when performing index computations for the given pointer-lik...
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
uint64_t getTypePreferredAlignment(Type t) const
Returns the preferred alignment of the given type in the current scope.
uint64_t getTypeABIAlignment(Type t) const
Returns the required alignment of the given type in the current scope.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
Type parseType(DialectAsmParser &parser)
Parses an LLVM dialect type.
Type getVectorType(Type elementType, unsigned numElements, bool isScalable=false)
Creates an LLVM dialect-compatible vector type with the given element type and length.
Definition: LLVMTypes.cpp:838
llvm::TypeSize getPrimitiveTypeSizeInBits(Type type)
Returns the size of the given primitive LLVM dialect-compatible type (including vectors) in bits,...
Definition: LLVMTypes.cpp:854
void printPrettyLLVMType(AsmPrinter &p, Type type)
Print any MLIR type or a concise syntax for LLVM types.
bool isScalableVectorType(Type vectorType)
Returns whether a vector type is scalable or not.
Definition: LLVMTypes.cpp:832
ParseResult parsePrettyLLVMType(AsmParser &p, Type &type)
Parse any MLIR type or a concise syntax for LLVM types.
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:811
bool isCompatibleOuterType(Type type)
Returns true if the given outer type is compatible with the LLVM dialect without checking its potenti...
Definition: LLVMTypes.cpp:702
PtrDLEntryPos
The positions of different values in the data layout entry for pointers.
Definition: LLVMTypes.h:136
std::optional< uint64_t > extractPointerSpecValue(Attribute attr, PtrDLEntryPos pos)
Returns the value that corresponds to named position pos from the data layout entry attr assuming it'...
Definition: LLVMTypes.cpp:263
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:793
bool isCompatibleFloatingPointType(Type type)
Returns true if the given type is a floating-point type compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:806
llvm::ElementCount getVectorNumElements(Type type)
Returns the element count of any LLVM-compatible vector type.
Definition: LLVMTypes.cpp:824
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type to an MLIR context if it was valid.
::llvm::MapVector<::mlir::StringAttr, ::mlir::DataLayoutEntryInterface > DataLayoutIdentifiedEntryMap
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:424