MLIR  20.0.0git
LLVMTypes.cpp
Go to the documentation of this file.
1 //===- LLVMTypes.cpp - MLIR LLVM dialect types ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the types for the LLVM dialect in MLIR. These MLIR types
10 // correspond to the LLVM IR type system.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "TypeDetail.h"
15 
18 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/TypeSupport.h"
21 
22 #include "llvm/ADT/ScopeExit.h"
23 #include "llvm/ADT/TypeSwitch.h"
24 #include "llvm/Support/TypeSize.h"
25 #include <optional>
26 
27 using namespace mlir;
28 using namespace mlir::LLVM;
29 
/// Number of bits in one byte. Used to convert between byte-based sizes and
/// alignments and the bit-based values stored in data layout entries.
static constexpr uint64_t kBitsInByte = 8;
31 
32 //===----------------------------------------------------------------------===//
33 // custom<FunctionTypes>
34 //===----------------------------------------------------------------------===//
35 
36 static ParseResult parseFunctionTypes(AsmParser &p, SmallVector<Type> &params,
37  bool &isVarArg) {
38  isVarArg = false;
39  // `(` `)`
40  if (succeeded(p.parseOptionalRParen()))
41  return success();
42 
43  // `(` `...` `)`
44  if (succeeded(p.parseOptionalEllipsis())) {
45  isVarArg = true;
46  return p.parseRParen();
47  }
48 
49  // type (`,` type)* (`,` `...`)?
50  Type type;
51  if (parsePrettyLLVMType(p, type))
52  return failure();
53  params.push_back(type);
54  while (succeeded(p.parseOptionalComma())) {
55  if (succeeded(p.parseOptionalEllipsis())) {
56  isVarArg = true;
57  return p.parseRParen();
58  }
59  if (parsePrettyLLVMType(p, type))
60  return failure();
61  params.push_back(type);
62  }
63  return p.parseRParen();
64 }
65 
67  bool isVarArg) {
68  llvm::interleaveComma(params, p,
69  [&](Type type) { printPrettyLLVMType(p, type); });
70  if (isVarArg) {
71  if (!params.empty())
72  p << ", ";
73  p << "...";
74  }
75  p << ')';
76 }
77 
78 //===----------------------------------------------------------------------===//
79 // custom<ExtTypeParams>
80 //===----------------------------------------------------------------------===//
81 
82 /// Parses the parameter list for a target extension type. The parameter list
83 /// contains an optional list of type parameters, followed by an optional list
84 /// of integer parameters. Type and integer parameters cannot be interleaved in
85 /// the list.
86 /// extTypeParams ::= typeList? | intList? | (typeList "," intList)
87 /// typeList ::= type ("," type)*
88 /// intList ::= integer ("," integer)*
89 static ParseResult
91  SmallVectorImpl<unsigned int> &intParams) {
92  bool parseType = true;
93  auto typeOrIntParser = [&]() -> ParseResult {
94  unsigned int i;
95  auto intResult = p.parseOptionalInteger(i);
96  if (intResult.has_value() && !failed(*intResult)) {
97  // Successfully parsed an integer.
98  intParams.push_back(i);
99  // After the first integer was successfully parsed, no
100  // more types can be parsed.
101  parseType = false;
102  return success();
103  }
104  if (parseType) {
105  Type t;
106  if (!parsePrettyLLVMType(p, t)) {
107  // Successfully parsed a type.
108  typeParams.push_back(t);
109  return success();
110  }
111  }
112  return failure();
113  };
114  if (p.parseCommaSeparatedList(typeOrIntParser)) {
116  "failed to parse parameter list for target extension type");
117  return failure();
118  }
119  return success();
120 }
121 
122 static void printExtTypeParams(AsmPrinter &p, ArrayRef<Type> typeParams,
123  ArrayRef<unsigned int> intParams) {
124  p << typeParams;
125  if (!typeParams.empty() && !intParams.empty())
126  p << ", ";
127 
128  p << intParams;
129 }
130 
131 //===----------------------------------------------------------------------===//
132 // ODS-Generated Definitions
133 //===----------------------------------------------------------------------===//
134 
135 /// These are unused for now.
136 /// TODO: Move over to these once more types have been migrated to TypeDef.
137 LLVM_ATTRIBUTE_UNUSED static OptionalParseResult
138 generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value);
139 LLVM_ATTRIBUTE_UNUSED static LogicalResult
141 
142 #include "mlir/Dialect/LLVMIR/LLVMTypeInterfaces.cpp.inc"
143 
144 #define GET_TYPEDEF_CLASSES
145 #include "mlir/Dialect/LLVMIR/LLVMTypes.cpp.inc"
146 
147 //===----------------------------------------------------------------------===//
148 // LLVMArrayType
149 //===----------------------------------------------------------------------===//
150 
151 bool LLVMArrayType::isValidElementType(Type type) {
152  return !llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
153  LLVMFunctionType, LLVMTokenType, LLVMScalableVectorType>(
154  type);
155 }
156 
157 LLVMArrayType LLVMArrayType::get(Type elementType, uint64_t numElements) {
158  assert(elementType && "expected non-null subtype");
159  return Base::get(elementType.getContext(), elementType, numElements);
160 }
161 
162 LLVMArrayType
163 LLVMArrayType::getChecked(function_ref<InFlightDiagnostic()> emitError,
164  Type elementType, uint64_t numElements) {
165  assert(elementType && "expected non-null subtype");
166  return Base::getChecked(emitError, elementType.getContext(), elementType,
167  numElements);
168 }
169 
170 LogicalResult
172  Type elementType, uint64_t numElements) {
173  if (!isValidElementType(elementType))
174  return emitError() << "invalid array element type: " << elementType;
175  return success();
176 }
177 
178 //===----------------------------------------------------------------------===//
179 // DataLayoutTypeInterface
180 
181 llvm::TypeSize
182 LLVMArrayType::getTypeSizeInBits(const DataLayout &dataLayout,
183  DataLayoutEntryListRef params) const {
184  return llvm::TypeSize::getFixed(kBitsInByte *
185  getTypeSize(dataLayout, params));
186 }
187 
188 llvm::TypeSize LLVMArrayType::getTypeSize(const DataLayout &dataLayout,
189  DataLayoutEntryListRef params) const {
190  return llvm::alignTo(dataLayout.getTypeSize(getElementType()),
191  dataLayout.getTypeABIAlignment(getElementType())) *
192  getNumElements();
193 }
194 
195 uint64_t LLVMArrayType::getABIAlignment(const DataLayout &dataLayout,
196  DataLayoutEntryListRef params) const {
197  return dataLayout.getTypeABIAlignment(getElementType());
198 }
199 
200 uint64_t
201 LLVMArrayType::getPreferredAlignment(const DataLayout &dataLayout,
202  DataLayoutEntryListRef params) const {
203  return dataLayout.getTypePreferredAlignment(getElementType());
204 }
205 
206 //===----------------------------------------------------------------------===//
207 // Function type.
208 //===----------------------------------------------------------------------===//
209 
210 bool LLVMFunctionType::isValidArgumentType(Type type) {
211  return !llvm::isa<LLVMVoidType, LLVMFunctionType>(type);
212 }
213 
214 bool LLVMFunctionType::isValidResultType(Type type) {
215  return !llvm::isa<LLVMFunctionType, LLVMMetadataType, LLVMLabelType>(type);
216 }
217 
218 LLVMFunctionType LLVMFunctionType::get(Type result, ArrayRef<Type> arguments,
219  bool isVarArg) {
220  assert(result && "expected non-null result");
221  return Base::get(result.getContext(), result, arguments, isVarArg);
222 }
223 
224 LLVMFunctionType
225 LLVMFunctionType::getChecked(function_ref<InFlightDiagnostic()> emitError,
226  Type result, ArrayRef<Type> arguments,
227  bool isVarArg) {
228  assert(result && "expected non-null result");
229  return Base::getChecked(emitError, result.getContext(), result, arguments,
230  isVarArg);
231 }
232 
233 LLVMFunctionType LLVMFunctionType::clone(TypeRange inputs,
234  TypeRange results) const {
235  assert(results.size() == 1 && "expected a single result type");
236  return get(results[0], llvm::to_vector(inputs), isVarArg());
237 }
238 
240  return static_cast<detail::LLVMFunctionTypeStorage *>(getImpl())->returnType;
241 }
242 
243 LogicalResult
245  Type result, ArrayRef<Type> arguments, bool) {
246  if (!isValidResultType(result))
247  return emitError() << "invalid function result type: " << result;
248 
249  for (Type arg : arguments)
250  if (!isValidArgumentType(arg))
251  return emitError() << "invalid function argument type: " << arg;
252 
253  return success();
254 }
255 
256 //===----------------------------------------------------------------------===//
257 // DataLayoutTypeInterface
258 
/// Layout assumed for pointers in the default address space when the data
/// layout has no entry for them: 64-bit pointers, aligned to 8 bytes.
static constexpr uint64_t kDefaultPointerSizeBits = 64;
static constexpr uint64_t kDefaultPointerAlignment = 8;
261 
262 std::optional<uint64_t> mlir::LLVM::extractPointerSpecValue(Attribute attr,
263  PtrDLEntryPos pos) {
264  auto spec = cast<DenseIntElementsAttr>(attr);
265  auto idx = static_cast<int64_t>(pos);
266  if (idx >= spec.size())
267  return std::nullopt;
268  return spec.getValues<uint64_t>()[idx];
269 }
270 
271 /// Returns the part of the data layout entry that corresponds to `pos` for the
272 /// given `type` by interpreting the list of entries `params`. For the pointer
273 /// type in the default address space, returns the default value if the entries
274 /// do not provide a custom one, for other address spaces returns std::nullopt.
275 static std::optional<uint64_t>
277  PtrDLEntryPos pos) {
278  // First, look for the entry for the pointer in the current address space.
279  Attribute currentEntry;
280  for (DataLayoutEntryInterface entry : params) {
281  if (!entry.isTypeEntry())
282  continue;
283  if (cast<LLVMPointerType>(cast<Type>(entry.getKey())).getAddressSpace() ==
284  type.getAddressSpace()) {
285  currentEntry = entry.getValue();
286  break;
287  }
288  }
289  if (currentEntry) {
290  std::optional<uint64_t> value = extractPointerSpecValue(currentEntry, pos);
291  // If the optional `PtrDLEntryPos::Index` entry is not available, use the
292  // pointer size as the index bitwidth.
293  if (!value && pos == PtrDLEntryPos::Index)
294  value = extractPointerSpecValue(currentEntry, PtrDLEntryPos::Size);
295  bool isSizeOrIndex =
296  pos == PtrDLEntryPos::Size || pos == PtrDLEntryPos::Index;
297  return *value / (isSizeOrIndex ? 1 : kBitsInByte);
298  }
299 
300  // If not found, and this is the pointer to the default memory space, assume
301  // 64-bit pointers.
302  if (type.getAddressSpace() == 0) {
303  bool isSizeOrIndex =
304  pos == PtrDLEntryPos::Size || pos == PtrDLEntryPos::Index;
305  return isSizeOrIndex ? kDefaultPointerSizeBits : kDefaultPointerAlignment;
306  }
307 
308  return std::nullopt;
309 }
310 
311 llvm::TypeSize
312 LLVMPointerType::getTypeSizeInBits(const DataLayout &dataLayout,
313  DataLayoutEntryListRef params) const {
314  if (std::optional<uint64_t> size =
316  return llvm::TypeSize::getFixed(*size);
317 
318  // For other memory spaces, use the size of the pointer to the default memory
319  // space.
320  return dataLayout.getTypeSizeInBits(get(getContext()));
321 }
322 
323 uint64_t LLVMPointerType::getABIAlignment(const DataLayout &dataLayout,
324  DataLayoutEntryListRef params) const {
325  if (std::optional<uint64_t> alignment =
327  return *alignment;
328 
329  return dataLayout.getTypeABIAlignment(get(getContext()));
330 }
331 
332 uint64_t
333 LLVMPointerType::getPreferredAlignment(const DataLayout &dataLayout,
334  DataLayoutEntryListRef params) const {
335  if (std::optional<uint64_t> alignment =
337  return *alignment;
338 
339  return dataLayout.getTypePreferredAlignment(get(getContext()));
340 }
341 
342 std::optional<uint64_t>
344  DataLayoutEntryListRef params) const {
345  if (std::optional<uint64_t> indexBitwidth =
347  return *indexBitwidth;
348 
349  return dataLayout.getTypeIndexBitwidth(get(getContext()));
350 }
351 
352 bool LLVMPointerType::areCompatible(DataLayoutEntryListRef oldLayout,
353  DataLayoutEntryListRef newLayout) const {
354  for (DataLayoutEntryInterface newEntry : newLayout) {
355  if (!newEntry.isTypeEntry())
356  continue;
357  uint64_t size = kDefaultPointerSizeBits;
358  uint64_t abi = kDefaultPointerAlignment;
359  auto newType =
360  llvm::cast<LLVMPointerType>(llvm::cast<Type>(newEntry.getKey()));
361  const auto *it =
362  llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
363  if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
364  return llvm::cast<LLVMPointerType>(type).getAddressSpace() ==
365  newType.getAddressSpace();
366  }
367  return false;
368  });
369  if (it == oldLayout.end()) {
370  llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
371  if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
372  return llvm::cast<LLVMPointerType>(type).getAddressSpace() == 0;
373  }
374  return false;
375  });
376  }
377  if (it != oldLayout.end()) {
380  }
381 
382  Attribute newSpec = llvm::cast<DenseIntElementsAttr>(newEntry.getValue());
383  uint64_t newSize = *extractPointerSpecValue(newSpec, PtrDLEntryPos::Size);
384  uint64_t newAbi = *extractPointerSpecValue(newSpec, PtrDLEntryPos::Abi);
385  if (size != newSize || abi < newAbi || abi % newAbi != 0)
386  return false;
387  }
388  return true;
389 }
390 
392  Location loc) const {
393  for (DataLayoutEntryInterface entry : entries) {
394  if (!entry.isTypeEntry())
395  continue;
396  auto key = llvm::cast<Type>(entry.getKey());
397  auto values = llvm::dyn_cast<DenseIntElementsAttr>(entry.getValue());
398  if (!values || (values.size() != 3 && values.size() != 4)) {
399  return emitError(loc)
400  << "expected layout attribute for " << key
401  << " to be a dense integer elements attribute with 3 or 4 "
402  "elements";
403  }
404  if (!values.getElementType().isInteger(64))
405  return emitError(loc) << "expected i64 parameters for " << key;
406 
409  return emitError(loc) << "preferred alignment is expected to be at least "
410  "as large as ABI alignment";
411  }
412  }
413  return success();
414 }
415 
416 //===----------------------------------------------------------------------===//
417 // Struct type.
418 //===----------------------------------------------------------------------===//
419 
420 bool LLVMStructType::isValidElementType(Type type) {
421  return !llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
422  LLVMFunctionType, LLVMTokenType>(type);
423 }
424 
425 LLVMStructType LLVMStructType::getIdentified(MLIRContext *context,
426  StringRef name) {
427  return Base::get(context, name, /*opaque=*/false);
428 }
429 
430 LLVMStructType LLVMStructType::getIdentifiedChecked(
432  StringRef name) {
433  return Base::getChecked(emitError, context, name, /*opaque=*/false);
434 }
435 
436 LLVMStructType LLVMStructType::getNewIdentified(MLIRContext *context,
437  StringRef name,
438  ArrayRef<Type> elements,
439  bool isPacked) {
440  std::string stringName = name.str();
441  unsigned counter = 0;
442  do {
443  auto type = LLVMStructType::getIdentified(context, stringName);
444  if (type.isInitialized() || failed(type.setBody(elements, isPacked))) {
445  counter += 1;
446  stringName = (Twine(name) + "." + std::to_string(counter)).str();
447  continue;
448  }
449  return type;
450  } while (true);
451 }
452 
453 LLVMStructType LLVMStructType::getLiteral(MLIRContext *context,
454  ArrayRef<Type> types, bool isPacked) {
455  return Base::get(context, types, isPacked);
456 }
457 
458 LLVMStructType
459 LLVMStructType::getLiteralChecked(function_ref<InFlightDiagnostic()> emitError,
460  MLIRContext *context, ArrayRef<Type> types,
461  bool isPacked) {
462  return Base::getChecked(emitError, context, types, isPacked);
463 }
464 
465 LLVMStructType LLVMStructType::getOpaque(StringRef name, MLIRContext *context) {
466  return Base::get(context, name, /*opaque=*/true);
467 }
468 
469 LLVMStructType
470 LLVMStructType::getOpaqueChecked(function_ref<InFlightDiagnostic()> emitError,
471  MLIRContext *context, StringRef name) {
472  return Base::getChecked(emitError, context, name, /*opaque=*/true);
473 }
474 
475 LogicalResult LLVMStructType::setBody(ArrayRef<Type> types, bool isPacked) {
476  assert(isIdentified() && "can only set bodies of identified structs");
477  assert(llvm::all_of(types, LLVMStructType::isValidElementType) &&
478  "expected valid body types");
479  return Base::mutate(types, isPacked);
480 }
481 
482 bool LLVMStructType::isPacked() const { return getImpl()->isPacked(); }
483 bool LLVMStructType::isIdentified() const { return getImpl()->isIdentified(); }
484 bool LLVMStructType::isOpaque() const {
485  return getImpl()->isIdentified() &&
486  (getImpl()->isOpaque() || !getImpl()->isInitialized());
487 }
488 bool LLVMStructType::isInitialized() { return getImpl()->isInitialized(); }
489 StringRef LLVMStructType::getName() const { return getImpl()->getIdentifier(); }
490 ArrayRef<Type> LLVMStructType::getBody() const {
491  return isIdentified() ? getImpl()->getIdentifiedStructBody()
492  : getImpl()->getTypeList();
493 }
494 
495 LogicalResult
496 LLVMStructType::verifyInvariants(function_ref<InFlightDiagnostic()>, StringRef,
497  bool) {
498  return success();
499 }
500 
501 LogicalResult
502 LLVMStructType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
503  ArrayRef<Type> types, bool) {
504  for (Type t : types)
505  if (!isValidElementType(t))
506  return emitError() << "invalid LLVM structure element type: " << t;
507 
508  return success();
509 }
510 
511 llvm::TypeSize
512 LLVMStructType::getTypeSizeInBits(const DataLayout &dataLayout,
513  DataLayoutEntryListRef params) const {
514  auto structSize = llvm::TypeSize::getFixed(0);
515  uint64_t structAlignment = 1;
516  for (Type element : getBody()) {
517  uint64_t elementAlignment =
518  isPacked() ? 1 : dataLayout.getTypeABIAlignment(element);
519  // Add padding to the struct size to align it to the abi alignment of the
520  // element type before than adding the size of the element.
521  structSize = llvm::alignTo(structSize, elementAlignment);
522  structSize += dataLayout.getTypeSize(element);
523 
524  // The alignment requirement of a struct is equal to the strictest alignment
525  // requirement of its elements.
526  structAlignment = std::max(elementAlignment, structAlignment);
527  }
528  // At the end, add padding to the struct to satisfy its own alignment
529  // requirement. Otherwise structs inside of arrays would be misaligned.
530  structSize = llvm::alignTo(structSize, structAlignment);
531  return structSize * kBitsInByte;
532 }
533 
namespace {
/// Positions of the values in a struct data layout entry. The ABI alignment
/// comes first and may be followed by the preferred alignment. Both are stored
/// in bits.
enum class StructDLEntryPos {
  Abi = 0,
  Preferred = 1,
};
} // namespace
537 
538 static std::optional<uint64_t>
540  StructDLEntryPos pos) {
541  const auto *currentEntry =
542  llvm::find_if(params, [](DataLayoutEntryInterface entry) {
543  return entry.isTypeEntry();
544  });
545  if (currentEntry == params.end())
546  return std::nullopt;
547 
548  auto attr = llvm::cast<DenseIntElementsAttr>(currentEntry->getValue());
549  if (pos == StructDLEntryPos::Preferred &&
550  attr.size() <= static_cast<int64_t>(StructDLEntryPos::Preferred))
551  // If no preferred was specified, fall back to abi alignment
552  pos = StructDLEntryPos::Abi;
553 
554  return attr.getValues<uint64_t>()[static_cast<size_t>(pos)];
555 }
556 
557 static uint64_t calculateStructAlignment(const DataLayout &dataLayout,
558  DataLayoutEntryListRef params,
559  LLVMStructType type,
560  StructDLEntryPos pos) {
561  // Packed structs always have an abi alignment of 1
562  if (pos == StructDLEntryPos::Abi && type.isPacked()) {
563  return 1;
564  }
565 
566  // The alignment requirement of a struct is equal to the strictest alignment
567  // requirement of its elements.
568  uint64_t structAlignment = 1;
569  for (Type iter : type.getBody()) {
570  structAlignment =
571  std::max(dataLayout.getTypeABIAlignment(iter), structAlignment);
572  }
573 
574  // Entries are only allowed to be stricter than the required alignment
575  if (std::optional<uint64_t> entryResult =
576  getStructDataLayoutEntry(params, type, pos))
577  return std::max(*entryResult / kBitsInByte, structAlignment);
578 
579  return structAlignment;
580 }
581 
582 uint64_t LLVMStructType::getABIAlignment(const DataLayout &dataLayout,
583  DataLayoutEntryListRef params) const {
584  return calculateStructAlignment(dataLayout, params, *this,
585  StructDLEntryPos::Abi);
586 }
587 
588 uint64_t
589 LLVMStructType::getPreferredAlignment(const DataLayout &dataLayout,
590  DataLayoutEntryListRef params) const {
591  return calculateStructAlignment(dataLayout, params, *this,
592  StructDLEntryPos::Preferred);
593 }
594 
595 static uint64_t extractStructSpecValue(Attribute attr, StructDLEntryPos pos) {
596  return llvm::cast<DenseIntElementsAttr>(attr)
597  .getValues<uint64_t>()[static_cast<size_t>(pos)];
598 }
599 
600 bool LLVMStructType::areCompatible(DataLayoutEntryListRef oldLayout,
601  DataLayoutEntryListRef newLayout) const {
602  for (DataLayoutEntryInterface newEntry : newLayout) {
603  if (!newEntry.isTypeEntry())
604  continue;
605 
606  const auto *previousEntry =
607  llvm::find_if(oldLayout, [](DataLayoutEntryInterface entry) {
608  return entry.isTypeEntry();
609  });
610  if (previousEntry == oldLayout.end())
611  continue;
612 
613  uint64_t abi = extractStructSpecValue(previousEntry->getValue(),
614  StructDLEntryPos::Abi);
615  uint64_t newAbi =
616  extractStructSpecValue(newEntry.getValue(), StructDLEntryPos::Abi);
617  if (abi < newAbi || abi % newAbi != 0)
618  return false;
619  }
620  return true;
621 }
622 
624  Location loc) const {
625  for (DataLayoutEntryInterface entry : entries) {
626  if (!entry.isTypeEntry())
627  continue;
628 
629  auto key = llvm::cast<LLVMStructType>(llvm::cast<Type>(entry.getKey()));
630  auto values = llvm::dyn_cast<DenseIntElementsAttr>(entry.getValue());
631  if (!values || (values.size() != 2 && values.size() != 1)) {
632  return emitError(loc)
633  << "expected layout attribute for "
634  << llvm::cast<Type>(entry.getKey())
635  << " to be a dense integer elements attribute of 1 or 2 elements";
636  }
637  if (!values.getElementType().isInteger(64))
638  return emitError(loc) << "expected i64 entries for " << key;
639 
640  if (key.isIdentified() || !key.getBody().empty()) {
641  return emitError(loc) << "unexpected layout attribute for struct " << key;
642  }
643 
644  if (values.size() == 1)
645  continue;
646 
647  if (extractStructSpecValue(values, StructDLEntryPos::Abi) >
648  extractStructSpecValue(values, StructDLEntryPos::Preferred)) {
649  return emitError(loc) << "preferred alignment is expected to be at least "
650  "as large as ABI alignment";
651  }
652  }
653  return mlir::success();
654 }
655 
656 //===----------------------------------------------------------------------===//
657 // Vector types.
658 //===----------------------------------------------------------------------===//
659 
660 /// Verifies that the type about to be constructed is well-formed.
661 template <typename VecTy>
662 static LogicalResult
664  Type elementType, unsigned numElements) {
665  if (numElements == 0)
666  return emitError() << "the number of vector elements must be positive";
667 
668  if (!VecTy::isValidElementType(elementType))
669  return emitError() << "invalid vector element type";
670 
671  return success();
672 }
673 
674 LLVMFixedVectorType LLVMFixedVectorType::get(Type elementType,
675  unsigned numElements) {
676  assert(elementType && "expected non-null subtype");
677  return Base::get(elementType.getContext(), elementType, numElements);
678 }
679 
680 LLVMFixedVectorType
681 LLVMFixedVectorType::getChecked(function_ref<InFlightDiagnostic()> emitError,
682  Type elementType, unsigned numElements) {
683  assert(elementType && "expected non-null subtype");
684  return Base::getChecked(emitError, elementType.getContext(), elementType,
685  numElements);
686 }
687 
688 bool LLVMFixedVectorType::isValidElementType(Type type) {
689  return llvm::isa<LLVMPointerType, LLVMPPCFP128Type>(type);
690 }
691 
692 LogicalResult
694  Type elementType, unsigned numElements) {
695  return verifyVectorConstructionInvariants<LLVMFixedVectorType>(
696  emitError, elementType, numElements);
697 }
698 
699 //===----------------------------------------------------------------------===//
700 // LLVMScalableVectorType.
701 //===----------------------------------------------------------------------===//
702 
703 LLVMScalableVectorType LLVMScalableVectorType::get(Type elementType,
704  unsigned minNumElements) {
705  assert(elementType && "expected non-null subtype");
706  return Base::get(elementType.getContext(), elementType, minNumElements);
707 }
708 
709 LLVMScalableVectorType
710 LLVMScalableVectorType::getChecked(function_ref<InFlightDiagnostic()> emitError,
711  Type elementType, unsigned minNumElements) {
712  assert(elementType && "expected non-null subtype");
713  return Base::getChecked(emitError, elementType.getContext(), elementType,
714  minNumElements);
715 }
716 
717 bool LLVMScalableVectorType::isValidElementType(Type type) {
718  if (auto intType = llvm::dyn_cast<IntegerType>(type))
719  return intType.isSignless();
720 
721  return isCompatibleFloatingPointType(type) ||
722  llvm::isa<LLVMPointerType>(type);
723 }
724 
725 LogicalResult
727  Type elementType, unsigned numElements) {
728  return verifyVectorConstructionInvariants<LLVMScalableVectorType>(
729  emitError, elementType, numElements);
730 }
731 
732 //===----------------------------------------------------------------------===//
733 // LLVMTargetExtType.
734 //===----------------------------------------------------------------------===//
735 
736 static constexpr llvm::StringRef kSpirvPrefix = "spirv.";
737 static constexpr llvm::StringRef kArmSVCount = "aarch64.svcount";
738 
739 bool LLVM::LLVMTargetExtType::hasProperty(Property prop) const {
740  // See llvm/lib/IR/Type.cpp for reference.
741  uint64_t properties = 0;
742 
743  if (getExtTypeName().starts_with(kSpirvPrefix))
744  properties |=
745  (LLVMTargetExtType::HasZeroInit | LLVM::LLVMTargetExtType::CanBeGlobal);
746 
747  return (properties & prop) == prop;
748 }
749 
750 bool LLVM::LLVMTargetExtType::supportsMemOps() const {
751  // See llvm/lib/IR/Type.cpp for reference.
752  if (getExtTypeName().starts_with(kSpirvPrefix))
753  return true;
754 
755  if (getExtTypeName() == kArmSVCount)
756  return true;
757 
758  return false;
759 }
760 
761 //===----------------------------------------------------------------------===//
762 // Utility functions.
763 //===----------------------------------------------------------------------===//
764 
766  // clang-format off
767  if (llvm::isa<
768  BFloat16Type,
769  Float16Type,
770  Float32Type,
771  Float64Type,
772  Float80Type,
773  Float128Type,
774  LLVMArrayType,
775  LLVMFunctionType,
776  LLVMLabelType,
777  LLVMMetadataType,
778  LLVMPPCFP128Type,
779  LLVMPointerType,
780  LLVMStructType,
781  LLVMTokenType,
782  LLVMFixedVectorType,
783  LLVMScalableVectorType,
784  LLVMTargetExtType,
785  LLVMVoidType,
786  LLVMX86AMXType
787  >(type)) {
788  // clang-format on
789  return true;
790  }
791 
792  // Only signless integers are compatible.
793  if (auto intType = llvm::dyn_cast<IntegerType>(type))
794  return intType.isSignless();
795 
796  // 1D vector types are compatible.
797  if (auto vecType = llvm::dyn_cast<VectorType>(type))
798  return vecType.getRank() == 1;
799 
800  return false;
801 }
802 
803 static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) {
804  if (!compatibleTypes.insert(type).second)
805  return true;
806 
807  auto isCompatible = [&](Type type) {
808  return isCompatibleImpl(type, compatibleTypes);
809  };
810 
811  bool result =
813  .Case<LLVMStructType>([&](auto structType) {
814  return llvm::all_of(structType.getBody(), isCompatible);
815  })
816  .Case<LLVMFunctionType>([&](auto funcType) {
817  return isCompatible(funcType.getReturnType()) &&
818  llvm::all_of(funcType.getParams(), isCompatible);
819  })
820  .Case<IntegerType>([](auto intType) { return intType.isSignless(); })
821  .Case<VectorType>([&](auto vecType) {
822  return vecType.getRank() == 1 &&
823  isCompatible(vecType.getElementType());
824  })
825  .Case<LLVMPointerType>([&](auto pointerType) { return true; })
826  .Case<LLVMTargetExtType>([&](auto extType) {
827  return llvm::all_of(extType.getTypeParams(), isCompatible);
828  })
829  // clang-format off
830  .Case<
831  LLVMFixedVectorType,
832  LLVMScalableVectorType,
833  LLVMArrayType
834  >([&](auto containerType) {
835  return isCompatible(containerType.getElementType());
836  })
837  .Case<
838  BFloat16Type,
839  Float16Type,
840  Float32Type,
841  Float64Type,
842  Float80Type,
843  Float128Type,
844  LLVMLabelType,
845  LLVMMetadataType,
846  LLVMPPCFP128Type,
847  LLVMTokenType,
848  LLVMVoidType,
849  LLVMX86AMXType
850  >([](Type) { return true; })
851  // clang-format on
852  .Default([](Type) { return false; });
853 
854  if (!result)
855  compatibleTypes.erase(type);
856 
857  return result;
858 }
859 
861  if (auto *llvmDialect =
862  type.getContext()->getLoadedDialect<LLVM::LLVMDialect>())
863  return isCompatibleImpl(type, llvmDialect->compatibleTypes.get());
864 
865  DenseSet<Type> localCompatibleTypes;
866  return isCompatibleImpl(type, localCompatibleTypes);
867 }
868 
870  return LLVMDialect::isCompatibleType(type);
871 }
872 
874  return llvm::isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
875  Float80Type, Float128Type, LLVMPPCFP128Type>(type);
876 }
877 
879  if (llvm::isa<LLVMFixedVectorType, LLVMScalableVectorType>(type))
880  return true;
881 
882  if (auto vecType = llvm::dyn_cast<VectorType>(type)) {
883  if (vecType.getRank() != 1)
884  return false;
885  Type elementType = vecType.getElementType();
886  if (auto intType = llvm::dyn_cast<IntegerType>(elementType))
887  return intType.isSignless();
888  return llvm::isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
889  Float80Type, Float128Type>(elementType);
890  }
891  return false;
892 }
893 
895  return llvm::TypeSwitch<Type, Type>(type)
896  .Case<LLVMFixedVectorType, LLVMScalableVectorType, VectorType>(
897  [](auto ty) { return ty.getElementType(); })
898  .Default([](Type) -> Type {
899  llvm_unreachable("incompatible with LLVM vector type");
900  });
901 }
902 
903 llvm::ElementCount mlir::LLVM::getVectorNumElements(Type type) {
905  .Case([](VectorType ty) {
906  if (ty.isScalable())
907  return llvm::ElementCount::getScalable(ty.getNumElements());
908  return llvm::ElementCount::getFixed(ty.getNumElements());
909  })
910  .Case([](LLVMFixedVectorType ty) {
911  return llvm::ElementCount::getFixed(ty.getNumElements());
912  })
913  .Case([](LLVMScalableVectorType ty) {
914  return llvm::ElementCount::getScalable(ty.getMinNumElements());
915  })
916  .Default([](Type) -> llvm::ElementCount {
917  llvm_unreachable("incompatible with LLVM vector type");
918  });
919 }
920 
922  assert((llvm::isa<LLVMFixedVectorType, LLVMScalableVectorType, VectorType>(
923  vectorType)) &&
924  "expected LLVM-compatible vector type");
925  return !llvm::isa<LLVMFixedVectorType>(vectorType) &&
926  (llvm::isa<LLVMScalableVectorType>(vectorType) ||
927  llvm::cast<VectorType>(vectorType).isScalable());
928 }
929 
930 Type mlir::LLVM::getVectorType(Type elementType, unsigned numElements,
931  bool isScalable) {
932  bool useLLVM = LLVMFixedVectorType::isValidElementType(elementType);
933  bool useBuiltIn = VectorType::isValidElementType(elementType);
934  (void)useBuiltIn;
935  assert((useLLVM ^ useBuiltIn) && "expected LLVM-compatible fixed-vector type "
936  "to be either builtin or LLVM dialect type");
937  if (useLLVM) {
938  if (isScalable)
939  return LLVMScalableVectorType::get(elementType, numElements);
940  return LLVMFixedVectorType::get(elementType, numElements);
941  }
942 
943  // LLVM vectors are always 1-D, hence only 1 bool is required to mark it as
944  // scalable/non-scalable.
945  return VectorType::get(numElements, elementType, {isScalable});
946 }
947 
949  const llvm::ElementCount &numElements) {
950  if (numElements.isScalable())
951  return getVectorType(elementType, numElements.getKnownMinValue(),
952  /*isScalable=*/true);
953  return getVectorType(elementType, numElements.getFixedValue(),
954  /*isScalable=*/false);
955 }
956 
957 Type mlir::LLVM::getFixedVectorType(Type elementType, unsigned numElements) {
958  bool useLLVM = LLVMFixedVectorType::isValidElementType(elementType);
959  bool useBuiltIn = VectorType::isValidElementType(elementType);
960  (void)useBuiltIn;
961  assert((useLLVM ^ useBuiltIn) && "expected LLVM-compatible fixed-vector type "
962  "to be either builtin or LLVM dialect type");
963  if (useLLVM)
964  return LLVMFixedVectorType::get(elementType, numElements);
965  return VectorType::get(numElements, elementType);
966 }
967 
968 Type mlir::LLVM::getScalableVectorType(Type elementType, unsigned numElements) {
969  bool useLLVM = LLVMScalableVectorType::isValidElementType(elementType);
970  bool useBuiltIn = VectorType::isValidElementType(elementType);
971  (void)useBuiltIn;
972  assert((useLLVM ^ useBuiltIn) && "expected LLVM-compatible scalable-vector "
973  "type to be either builtin or LLVM dialect "
974  "type");
975  if (useLLVM)
976  return LLVMScalableVectorType::get(elementType, numElements);
977 
978  // LLVM vectors are always 1-D, hence only 1 bool is required to mark it as
979  // scalable/non-scalable.
980  return VectorType::get(numElements, elementType, /*scalableDims=*/true);
981 }
982 
984  assert(isCompatibleType(type) &&
985  "expected a type compatible with the LLVM dialect");
986 
988  .Case<BFloat16Type, Float16Type>(
989  [](Type) { return llvm::TypeSize::getFixed(16); })
990  .Case<Float32Type>([](Type) { return llvm::TypeSize::getFixed(32); })
991  .Case<Float64Type>([](Type) { return llvm::TypeSize::getFixed(64); })
992  .Case<Float80Type>([](Type) { return llvm::TypeSize::getFixed(80); })
993  .Case<Float128Type>([](Type) { return llvm::TypeSize::getFixed(128); })
994  .Case<IntegerType>([](IntegerType intTy) {
995  return llvm::TypeSize::getFixed(intTy.getWidth());
996  })
997  .Case<LLVMPPCFP128Type>(
998  [](Type) { return llvm::TypeSize::getFixed(128); })
999  .Case<LLVMFixedVectorType>([](LLVMFixedVectorType t) {
1000  llvm::TypeSize elementSize =
1001  getPrimitiveTypeSizeInBits(t.getElementType());
1002  return llvm::TypeSize(elementSize.getFixedValue() * t.getNumElements(),
1003  elementSize.isScalable());
1004  })
1005  .Case<VectorType>([](VectorType t) {
1006  assert(isCompatibleVectorType(t) &&
1007  "unexpected incompatible with LLVM vector type");
1008  llvm::TypeSize elementSize =
1009  getPrimitiveTypeSizeInBits(t.getElementType());
1010  return llvm::TypeSize(elementSize.getFixedValue() * t.getNumElements(),
1011  elementSize.isScalable());
1012  })
1013  .Default([](Type ty) {
1014  assert((llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
1015  LLVMTokenType, LLVMStructType, LLVMArrayType,
1016  LLVMPointerType, LLVMFunctionType, LLVMTargetExtType>(
1017  ty)) &&
1018  "unexpected missing support for primitive type");
1019  return llvm::TypeSize::getFixed(0);
1020  });
1021 }
1022 
1023 //===----------------------------------------------------------------------===//
1024 // LLVMDialect
1025 //===----------------------------------------------------------------------===//
1026 
1027 void LLVMDialect::registerTypes() {
1028  addTypes<
1029 #define GET_TYPEDEF_LIST
1030 #include "mlir/Dialect/LLVMIR/LLVMTypes.cpp.inc"
1031  >();
1032 }
1033 
1035  return detail::parseType(parser);
1036 }
1037 
1038 void LLVMDialect::printType(Type type, DialectAsmPrinter &os) const {
1039  return detail::printType(type, os);
1040 }
static LogicalResult verifyEntries(function_ref< InFlightDiagnostic()> emitError, ArrayRef< DataLayoutEntryInterface > entries, bool allowTypes=true)
Verify entries, with the option to disallow types as keys.
Definition: DLTI.cpp:136
static uint64_t getIndexBitwidth(DataLayoutEntryListRef params)
Returns the bitwidth of the index type if specified in the param list.
static MLIRContext * getContext(OpFoldResult val)
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
constexpr static const uint64_t kDefaultPointerAlignment
Definition: LLVMTypes.cpp:260
static void printExtTypeParams(AsmPrinter &p, ArrayRef< Type > typeParams, ArrayRef< unsigned int > intParams)
Definition: LLVMTypes.cpp:122
constexpr static const uint64_t kDefaultPointerSizeBits
Definition: LLVMTypes.cpp:259
static LLVM_ATTRIBUTE_UNUSED OptionalParseResult generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value)
These are unused for now.
static void printFunctionTypes(AsmPrinter &p, ArrayRef< Type > params, bool isVarArg)
Definition: LLVMTypes.cpp:66
static bool isCompatibleImpl(Type type, DenseSet< Type > &compatibleTypes)
Definition: LLVMTypes.cpp:803
static ParseResult parseExtTypeParams(AsmParser &p, SmallVectorImpl< Type > &typeParams, SmallVectorImpl< unsigned int > &intParams)
Parses the parameter list for a target extension type.
Definition: LLVMTypes.cpp:90
static LLVM_ATTRIBUTE_UNUSED LogicalResult generatedTypePrinter(Type def, AsmPrinter &printer)
static constexpr llvm::StringRef kSpirvPrefix
Definition: LLVMTypes.cpp:736
static LogicalResult verifyVectorConstructionInvariants(function_ref< InFlightDiagnostic()> emitError, Type elementType, unsigned numElements)
Verifies that the type about to be constructed is well-formed.
Definition: LLVMTypes.cpp:663
static ParseResult parseFunctionTypes(AsmParser &p, SmallVector< Type > &params, bool &isVarArg)
Definition: LLVMTypes.cpp:36
constexpr static const uint64_t kBitsInByte
Definition: LLVMTypes.cpp:30
static constexpr llvm::StringRef kArmSVCount
Definition: LLVMTypes.cpp:737
static std::optional< uint64_t > getPointerDataLayoutEntry(DataLayoutEntryListRef params, LLVMPointerType type, PtrDLEntryPos pos)
Returns the part of the data layout entry that corresponds to pos for the given type by interpreting ...
Definition: LLVMTypes.cpp:276
static std::optional< uint64_t > getStructDataLayoutEntry(DataLayoutEntryListRef params, LLVMStructType type, StructDLEntryPos pos)
Definition: LLVMTypes.cpp:539
static uint64_t extractStructSpecValue(Attribute attr, StructDLEntryPos pos)
Definition: LLVMTypes.cpp:595
static uint64_t calculateStructAlignment(const DataLayout &dataLayout, DataLayoutEntryListRef params, LLVMStructType type, StructDLEntryPos pos)
Definition: LLVMTypes.cpp:557
static SmallVector< Type > getReturnTypes(SmallVector< func::ReturnOp > returnOps)
Helper function that returns the return types (skipping casts) of the given func.return ops.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:215
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseOptionalEllipsis()=0
Parse a ... token if present.
This base class exposes generic asm printer hooks, usable across the various derived printers.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
The main mechanism for performing data layout queries.
std::optional< uint64_t > getTypeIndexBitwidth(Type t) const
Returns the bitwidth that should be used when performing index computations for the given pointer-lik...
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
uint64_t getTypePreferredAlignment(Type t) const
Returns the preferred alignment of the given type in the current scope.
uint64_t getTypeABIAlignment(Type t) const
Returns the required alignment of the given type in the current scope.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
Type parseType(DialectAsmParser &parser)
Parses an LLVM dialect type.
Type getVectorType(Type elementType, unsigned numElements, bool isScalable=false)
Creates an LLVM dialect-compatible vector type with the given element type and length.
Definition: LLVMTypes.cpp:930
llvm::TypeSize getPrimitiveTypeSizeInBits(Type type)
Returns the size of the given primitive LLVM dialect-compatible type (including vectors) in bits,...
Definition: LLVMTypes.cpp:983
void printPrettyLLVMType(AsmPrinter &p, Type type)
Print any MLIR type or a concise syntax for LLVM types.
bool isScalableVectorType(Type vectorType)
Returns whether a vector type is scalable or not.
Definition: LLVMTypes.cpp:921
ParseResult parsePrettyLLVMType(AsmParser &p, Type &type)
Parse any MLIR type or a concise syntax for LLVM types.
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:878
bool isCompatibleOuterType(Type type)
Returns true if the given outer type is compatible with the LLVM dialect without checking its potenti...
Definition: LLVMTypes.cpp:765
Type getFixedVectorType(Type elementType, unsigned numElements)
Creates an LLVM dialect-compatible type with the given element type and length.
Definition: LLVMTypes.cpp:957
PtrDLEntryPos
The positions of different values in the data layout entry for pointers.
Definition: LLVMTypes.h:149
std::optional< uint64_t > extractPointerSpecValue(Attribute attr, PtrDLEntryPos pos)
Returns the value that corresponds to named position pos from the data layout entry attr assuming it'...
Definition: LLVMTypes.cpp:262
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:860
Type getScalableVectorType(Type elementType, unsigned numElements)
Creates an LLVM dialect-compatible type with the given element type and length.
Definition: LLVMTypes.cpp:968
bool isCompatibleFloatingPointType(Type type)
Returns true if the given type is a floating-point type compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:873
llvm::ElementCount getVectorNumElements(Type type)
Returns the element count of any LLVM-compatible vector type.
Definition: LLVMTypes.cpp:903
Type getVectorElementType(Type type)
Returns the element type of any vector type compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:894
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type to an MLIR context if it was valid.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:425