MLIR  16.0.0git
LLVMTypeSyntax.cpp
Go to the documentation of this file.
1 //===- LLVMTypeSyntax.cpp - Parsing/printing for MLIR LLVM Dialect types --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/IR/Builders.h"
12 #include "llvm/ADT/ScopeExit.h"
13 #include "llvm/ADT/SetVector.h"
14 #include "llvm/ADT/TypeSwitch.h"
15 
16 using namespace mlir;
17 using namespace mlir::LLVM;
18 
19 //===----------------------------------------------------------------------===//
20 // Printing.
21 //===----------------------------------------------------------------------===//
22 
23 /// If the given type is compatible with the LLVM dialect, prints it using
24 /// internal functions to avoid getting a verbose `!llvm` prefix. Otherwise
25 /// prints it as usual.
26 static void dispatchPrint(AsmPrinter &printer, Type type) {
27  if (isCompatibleType(type) && !type.isa<IntegerType, FloatType, VectorType>())
28  return mlir::LLVM::detail::printType(type, printer);
29  printer.printType(type);
30 }
31 
32 /// Returns the keyword to use for the given type.
33 static StringRef getTypeKeyword(Type type) {
34  return TypeSwitch<Type, StringRef>(type)
35  .Case<LLVMVoidType>([&](Type) { return "void"; })
36  .Case<LLVMPPCFP128Type>([&](Type) { return "ppc_fp128"; })
37  .Case<LLVMX86MMXType>([&](Type) { return "x86_mmx"; })
38  .Case<LLVMTokenType>([&](Type) { return "token"; })
39  .Case<LLVMLabelType>([&](Type) { return "label"; })
40  .Case<LLVMMetadataType>([&](Type) { return "metadata"; })
41  .Case<LLVMFunctionType>([&](Type) { return "func"; })
42  .Case<LLVMPointerType>([&](Type) { return "ptr"; })
43  .Case<LLVMFixedVectorType, LLVMScalableVectorType>(
44  [&](Type) { return "vec"; })
45  .Case<LLVMArrayType>([&](Type) { return "array"; })
46  .Case<LLVMStructType>([&](Type) { return "struct"; })
47  .Default([](Type) -> StringRef {
48  llvm_unreachable("unexpected 'llvm' type kind");
49  });
50 }
51 
52 /// Prints a structure type. Keeps track of known struct names to handle self-
53 /// or mutually-referring structs without falling into infinite recursion.
54 static void printStructType(AsmPrinter &printer, LLVMStructType type) {
55  // This keeps track of the names of identified structure types that are
56  // currently being printed. Since such types can refer themselves, this
57  // tracking is necessary to stop the recursion: the current function may be
58  // called recursively from AsmPrinter::printType after the appropriate
59  // dispatch. We maintain the invariant of this storage being modified
60  // exclusively in this function, and at most one name being added per call.
61  // TODO: consider having such functionality inside AsmPrinter.
62  thread_local SetVector<StringRef> knownStructNames;
63  unsigned stackSize = knownStructNames.size();
64  (void)stackSize;
65  auto guard = llvm::make_scope_exit([&]() {
66  assert(knownStructNames.size() == stackSize &&
67  "malformed identified stack when printing recursive structs");
68  });
69 
70  printer << "<";
71  if (type.isIdentified()) {
72  printer << '"' << type.getName() << '"';
73  // If we are printing a reference to one of the enclosing structs, just
74  // print the name and stop to avoid infinitely long output.
75  if (knownStructNames.count(type.getName())) {
76  printer << '>';
77  return;
78  }
79  printer << ", ";
80  }
81 
82  if (type.isIdentified() && type.isOpaque()) {
83  printer << "opaque>";
84  return;
85  }
86 
87  if (type.isPacked())
88  printer << "packed ";
89 
90  // Put the current type on stack to avoid infinite recursion.
91  printer << '(';
92  if (type.isIdentified())
93  knownStructNames.insert(type.getName());
94  llvm::interleaveComma(type.getBody(), printer.getStream(),
95  [&](Type subtype) { dispatchPrint(printer, subtype); });
96  if (type.isIdentified())
97  knownStructNames.pop_back();
98  printer << ')';
99  printer << '>';
100 }
101 
102 /// Prints a type containing a fixed number of elements.
103 template <typename TypeTy>
104 static void printArrayOrVectorType(AsmPrinter &printer, TypeTy type) {
105  printer << '<' << type.getNumElements() << " x ";
106  dispatchPrint(printer, type.getElementType());
107  printer << '>';
108 }
109 
110 /// Prints a function type.
111 static void printFunctionType(AsmPrinter &printer, LLVMFunctionType funcType) {
112  printer << '<';
113  dispatchPrint(printer, funcType.getReturnType());
114  printer << " (";
115  llvm::interleaveComma(
116  funcType.getParams(), printer.getStream(),
117  [&printer](Type subtype) { dispatchPrint(printer, subtype); });
118  if (funcType.isVarArg()) {
119  if (funcType.getNumParams() != 0)
120  printer << ", ";
121  printer << "...";
122  }
123  printer << ")>";
124 }
125 
126 /// Prints the given LLVM dialect type recursively. This leverages closedness of
127 /// the LLVM dialect type system to avoid printing the dialect prefix
128 /// repeatedly. For recursive structures, only prints the name of the structure
129 /// when printing a self-reference. Note that this does not apply to sibling
130 /// references. For example,
131 /// struct<"a", (ptr<struct<"a">>)>
132 /// struct<"c", (ptr<struct<"b", (ptr<struct<"c">>)>>,
133 /// ptr<struct<"b", (ptr<struct<"c">>)>>)>
134 /// note that "b" is printed twice.
135 void mlir::LLVM::detail::printType(Type type, AsmPrinter &printer) {
136  if (!type) {
137  printer << "<<NULL-TYPE>>";
138  return;
139  }
140 
141  printer << getTypeKeyword(type);
142 
143  if (auto ptrType = type.dyn_cast<LLVMPointerType>()) {
144  if (ptrType.isOpaque()) {
145  if (ptrType.getAddressSpace() != 0)
146  printer << '<' << ptrType.getAddressSpace() << '>';
147  return;
148  }
149 
150  printer << '<';
151  dispatchPrint(printer, ptrType.getElementType());
152  if (ptrType.getAddressSpace() != 0)
153  printer << ", " << ptrType.getAddressSpace();
154  printer << '>';
155  return;
156  }
157 
158  if (auto arrayType = type.dyn_cast<LLVMArrayType>())
159  return printArrayOrVectorType(printer, arrayType);
160  if (auto vectorType = type.dyn_cast<LLVMFixedVectorType>())
161  return printArrayOrVectorType(printer, vectorType);
162 
163  if (auto vectorType = type.dyn_cast<LLVMScalableVectorType>()) {
164  printer << "<? x " << vectorType.getMinNumElements() << " x ";
165  dispatchPrint(printer, vectorType.getElementType());
166  printer << '>';
167  return;
168  }
169 
170  if (auto structType = type.dyn_cast<LLVMStructType>())
171  return printStructType(printer, structType);
172 
173  if (auto funcType = type.dyn_cast<LLVMFunctionType>())
174  return printFunctionType(printer, funcType);
175 }
176 
177 //===----------------------------------------------------------------------===//
178 // Parsing.
179 //===----------------------------------------------------------------------===//
180 
181 static ParseResult dispatchParse(AsmParser &parser, Type &type);
182 
183 /// Parses an LLVM dialect function type.
184 /// llvm-type :: = `func<` llvm-type `(` llvm-type-list `...`? `)>`
185 static LLVMFunctionType parseFunctionType(AsmParser &parser) {
186  SMLoc loc = parser.getCurrentLocation();
187  Type returnType;
188  if (parser.parseLess() || dispatchParse(parser, returnType) ||
189  parser.parseLParen())
190  return LLVMFunctionType();
191 
192  // Function type without arguments.
193  if (succeeded(parser.parseOptionalRParen())) {
194  if (succeeded(parser.parseGreater()))
195  return parser.getChecked<LLVMFunctionType>(loc, returnType, llvm::None,
196  /*isVarArg=*/false);
197  return LLVMFunctionType();
198  }
199 
200  // Parse arguments.
201  SmallVector<Type, 8> argTypes;
202  do {
203  if (succeeded(parser.parseOptionalEllipsis())) {
204  if (parser.parseOptionalRParen() || parser.parseOptionalGreater())
205  return LLVMFunctionType();
206  return parser.getChecked<LLVMFunctionType>(loc, returnType, argTypes,
207  /*isVarArg=*/true);
208  }
209 
210  Type arg;
211  if (dispatchParse(parser, arg))
212  return LLVMFunctionType();
213  argTypes.push_back(arg);
214  } while (succeeded(parser.parseOptionalComma()));
215 
216  if (parser.parseOptionalRParen() || parser.parseOptionalGreater())
217  return LLVMFunctionType();
218  return parser.getChecked<LLVMFunctionType>(loc, returnType, argTypes,
219  /*isVarArg=*/false);
220 }
221 
222 /// Parses an LLVM dialect pointer type.
223 /// llvm-type ::= `ptr<` llvm-type (`,` integer)? `>`
224 /// | `ptr` (`<` integer `>`)?
225 static LLVMPointerType parsePointerType(AsmParser &parser) {
226  SMLoc loc = parser.getCurrentLocation();
227  Type elementType;
228  if (parser.parseOptionalLess()) {
229  return parser.getChecked<LLVMPointerType>(loc, parser.getContext(),
230  /*addressSpace=*/0);
231  }
232 
233  unsigned addressSpace = 0;
234  OptionalParseResult opr = parser.parseOptionalInteger(addressSpace);
235  if (opr.has_value()) {
236  if (failed(*opr) || parser.parseGreater())
237  return LLVMPointerType();
238  return parser.getChecked<LLVMPointerType>(loc, parser.getContext(),
239  addressSpace);
240  }
241 
242  if (dispatchParse(parser, elementType))
243  return LLVMPointerType();
244 
245  if (succeeded(parser.parseOptionalComma()) &&
246  failed(parser.parseInteger(addressSpace)))
247  return LLVMPointerType();
248  if (failed(parser.parseGreater()))
249  return LLVMPointerType();
250  return parser.getChecked<LLVMPointerType>(loc, elementType, addressSpace);
251 }
252 
253 /// Parses an LLVM dialect vector type.
254 /// llvm-type ::= `vec<` `? x`? integer `x` llvm-type `>`
255 /// Supports both fixed and scalable vectors.
256 static Type parseVectorType(AsmParser &parser) {
257  SmallVector<int64_t, 2> dims;
258  SMLoc dimPos, typePos;
259  Type elementType;
260  SMLoc loc = parser.getCurrentLocation();
261  if (parser.parseLess() || parser.getCurrentLocation(&dimPos) ||
262  parser.parseDimensionList(dims, /*allowDynamic=*/true) ||
263  parser.getCurrentLocation(&typePos) ||
264  dispatchParse(parser, elementType) || parser.parseGreater())
265  return Type();
266 
267  // We parsed a generic dimension list, but vectors only support two forms:
268  // - single non-dynamic entry in the list (fixed vector);
269  // - two elements, the first dynamic (indicated by -1) and the second
270  // non-dynamic (scalable vector).
271  if (dims.empty() || dims.size() > 2 ||
272  ((dims.size() == 2) ^ (dims[0] == -1)) ||
273  (dims.size() == 2 && dims[1] == -1)) {
274  parser.emitError(dimPos)
275  << "expected '? x <integer> x <type>' or '<integer> x <type>'";
276  return Type();
277  }
278 
279  bool isScalable = dims.size() == 2;
280  if (isScalable)
281  return parser.getChecked<LLVMScalableVectorType>(loc, elementType, dims[1]);
282  if (elementType.isSignlessIntOrFloat()) {
283  parser.emitError(typePos)
284  << "cannot use !llvm.vec for built-in primitives, use 'vector' instead";
285  return Type();
286  }
287  return parser.getChecked<LLVMFixedVectorType>(loc, elementType, dims[0]);
288 }
289 
290 /// Parses an LLVM dialect array type.
291 /// llvm-type ::= `array<` integer `x` llvm-type `>`
292 static LLVMArrayType parseArrayType(AsmParser &parser) {
293  SmallVector<int64_t, 1> dims;
294  SMLoc sizePos;
295  Type elementType;
296  SMLoc loc = parser.getCurrentLocation();
297  if (parser.parseLess() || parser.getCurrentLocation(&sizePos) ||
298  parser.parseDimensionList(dims, /*allowDynamic=*/false) ||
299  dispatchParse(parser, elementType) || parser.parseGreater())
300  return LLVMArrayType();
301 
302  if (dims.size() != 1) {
303  parser.emitError(sizePos) << "expected ? x <type>";
304  return LLVMArrayType();
305  }
306 
307  return parser.getChecked<LLVMArrayType>(loc, elementType, dims[0]);
308 }
309 
310 /// Attempts to set the body of an identified structure type. Reports a parsing
311 /// error at `subtypesLoc` in case of failure.
312 static LLVMStructType trySetStructBody(LLVMStructType type,
313  ArrayRef<Type> subtypes, bool isPacked,
314  AsmParser &parser, SMLoc subtypesLoc) {
315  for (Type t : subtypes) {
316  if (!LLVMStructType::isValidElementType(t)) {
317  parser.emitError(subtypesLoc)
318  << "invalid LLVM structure element type: " << t;
319  return LLVMStructType();
320  }
321  }
322 
323  if (succeeded(type.setBody(subtypes, isPacked)))
324  return type;
325 
326  parser.emitError(subtypesLoc)
327  << "identified type already used with a different body";
328  return LLVMStructType();
329 }
330 
331 /// Parses an LLVM dialect structure type.
332 /// llvm-type ::= `struct<` (string-literal `,`)? `packed`?
333 /// `(` llvm-type-list `)` `>`
334 /// | `struct<` string-literal `>`
335 /// | `struct<` string-literal `, opaque>`
336 static LLVMStructType parseStructType(AsmParser &parser) {
337  // This keeps track of the names of identified structure types that are
338  // currently being parsed. Since such types can refer themselves, this
339  // tracking is necessary to stop the recursion: the current function may be
340  // called recursively from AsmParser::parseType after the appropriate
341  // dispatch. We maintain the invariant of this storage being modified
342  // exclusively in this function, and at most one name being added per call.
343  // TODO: consider having such functionality inside AsmParser.
344  thread_local SetVector<StringRef> knownStructNames;
345  unsigned stackSize = knownStructNames.size();
346  (void)stackSize;
347  auto guard = llvm::make_scope_exit([&]() {
348  assert(knownStructNames.size() == stackSize &&
349  "malformed identified stack when parsing recursive structs");
350  });
351 
352  Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
353 
354  if (failed(parser.parseLess()))
355  return LLVMStructType();
356 
357  // If we are parsing a self-reference to a recursive struct, i.e. the parsing
358  // stack already contains a struct with the same identifier, bail out after
359  // the name.
360  std::string name;
361  bool isIdentified = succeeded(parser.parseOptionalString(&name));
362  if (isIdentified) {
363  if (knownStructNames.count(name)) {
364  if (failed(parser.parseGreater()))
365  return LLVMStructType();
366  return LLVMStructType::getIdentifiedChecked(
367  [loc] { return emitError(loc); }, loc.getContext(), name);
368  }
369  if (failed(parser.parseComma()))
370  return LLVMStructType();
371  }
372 
373  // Handle intentionally opaque structs.
374  SMLoc kwLoc = parser.getCurrentLocation();
375  if (succeeded(parser.parseOptionalKeyword("opaque"))) {
376  if (!isIdentified)
377  return parser.emitError(kwLoc, "only identified structs can be opaque"),
378  LLVMStructType();
379  if (failed(parser.parseGreater()))
380  return LLVMStructType();
381  auto type = LLVMStructType::getOpaqueChecked(
382  [loc] { return emitError(loc); }, loc.getContext(), name);
383  if (!type.isOpaque()) {
384  parser.emitError(kwLoc, "redeclaring defined struct as opaque");
385  return LLVMStructType();
386  }
387  return type;
388  }
389 
390  // Check for packedness.
391  bool isPacked = succeeded(parser.parseOptionalKeyword("packed"));
392  if (failed(parser.parseLParen()))
393  return LLVMStructType();
394 
395  // Fast pass for structs with zero subtypes.
396  if (succeeded(parser.parseOptionalRParen())) {
397  if (failed(parser.parseGreater()))
398  return LLVMStructType();
399  if (!isIdentified)
400  return LLVMStructType::getLiteralChecked([loc] { return emitError(loc); },
401  loc.getContext(), {}, isPacked);
402  auto type = LLVMStructType::getIdentifiedChecked(
403  [loc] { return emitError(loc); }, loc.getContext(), name);
404  return trySetStructBody(type, {}, isPacked, parser, kwLoc);
405  }
406 
407  // Parse subtypes. For identified structs, put the identifier of the struct on
408  // the stack to support self-references in the recursive calls.
409  SmallVector<Type, 4> subtypes;
410  SMLoc subtypesLoc = parser.getCurrentLocation();
411  do {
412  if (isIdentified)
413  knownStructNames.insert(name);
414  Type type;
415  if (dispatchParse(parser, type))
416  return LLVMStructType();
417  subtypes.push_back(type);
418  if (isIdentified)
419  knownStructNames.pop_back();
420  } while (succeeded(parser.parseOptionalComma()));
421 
422  if (parser.parseRParen() || parser.parseGreater())
423  return LLVMStructType();
424 
425  // Construct the struct with body.
426  if (!isIdentified)
427  return LLVMStructType::getLiteralChecked(
428  [loc] { return emitError(loc); }, loc.getContext(), subtypes, isPacked);
429  auto type = LLVMStructType::getIdentifiedChecked(
430  [loc] { return emitError(loc); }, loc.getContext(), name);
431  return trySetStructBody(type, subtypes, isPacked, parser, subtypesLoc);
432 }
433 
434 /// Parses a type appearing inside another LLVM dialect-compatible type. This
435 /// will try to parse any type in full form (including types with the `!llvm`
436 /// prefix), and on failure fall back to parsing the short-hand version of the
437 /// LLVM dialect types without the `!llvm` prefix.
438 static Type dispatchParse(AsmParser &parser, bool allowAny = true) {
439  SMLoc keyLoc = parser.getCurrentLocation();
440 
441  // Try parsing any MLIR type.
442  Type type;
443  OptionalParseResult result = parser.parseOptionalType(type);
444  if (result.has_value()) {
445  if (failed(result.value()))
446  return nullptr;
447  if (!allowAny) {
448  parser.emitError(keyLoc) << "unexpected type, expected keyword";
449  return nullptr;
450  }
451  return type;
452  }
453 
454  // If no type found, fallback to the shorthand form.
455  StringRef key;
456  if (failed(parser.parseKeyword(&key)))
457  return Type();
458 
459  MLIRContext *ctx = parser.getContext();
460  return StringSwitch<function_ref<Type()>>(key)
461  .Case("void", [&] { return LLVMVoidType::get(ctx); })
462  .Case("ppc_fp128", [&] { return LLVMPPCFP128Type::get(ctx); })
463  .Case("x86_mmx", [&] { return LLVMX86MMXType::get(ctx); })
464  .Case("token", [&] { return LLVMTokenType::get(ctx); })
465  .Case("label", [&] { return LLVMLabelType::get(ctx); })
466  .Case("metadata", [&] { return LLVMMetadataType::get(ctx); })
467  .Case("func", [&] { return parseFunctionType(parser); })
468  .Case("ptr", [&] { return parsePointerType(parser); })
469  .Case("vec", [&] { return parseVectorType(parser); })
470  .Case("array", [&] { return parseArrayType(parser); })
471  .Case("struct", [&] { return parseStructType(parser); })
472  .Default([&] {
473  parser.emitError(keyLoc) << "unknown LLVM type: " << key;
474  return Type();
475  })();
476 }
477 
478 /// Helper to use in parse lists.
479 static ParseResult dispatchParse(AsmParser &parser, Type &type) {
480  type = dispatchParse(parser);
481  return success(type != nullptr);
482 }
483 
484 /// Parses one of the LLVM dialect types.
485 Type mlir::LLVM::detail::parseType(DialectAsmParser &parser) {
486  SMLoc loc = parser.getCurrentLocation();
487  Type type = dispatchParse(parser, /*allowAny=*/false);
488  if (!type)
489  return type;
490  if (!isCompatibleOuterType(type)) {
491  parser.emitError(loc) << "unexpected type, expected keyword";
492  return nullptr;
493  }
494  return type;
495 }
Include the generated interface declarations.
StringRef getName()
Returns the name of an identified struct.
Definition: LLVMTypes.cpp:465
virtual void printType(Type type)
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:854
static void printStructType(AsmPrinter &printer, LLVMStructType type)
Prints a structure type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
static void dispatchPrint(AsmPrinter &printer, Type type)
If the given type is compatible with the LLVM dialect, prints it using internal functions to avoid ge...
static StringRef getTypeKeyword(Type type)
Returns the keyword to use for the given type.
LLVM dialect structure type representing a collection of different-typed elements manipulated togethe...
Definition: LLVMTypes.h:283
bool isIdentified() const
Checks if a struct is identified.
Definition: LLVMTypes.cpp:459
This base class exposes generic asm printer hooks, usable across the various derived printers...
bool isa() const
Definition: Types.h:254