MLIR  14.0.0git
LLVMTypeSyntax.cpp
Go to the documentation of this file.
1 //===- LLVMTypeSyntax.cpp - Parsing/printing for MLIR LLVM Dialect types --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/IR/Builders.h"
12 #include "llvm/ADT/ScopeExit.h"
13 #include "llvm/ADT/SetVector.h"
14 #include "llvm/ADT/TypeSwitch.h"
15 
16 using namespace mlir;
17 using namespace mlir::LLVM;
18 
19 //===----------------------------------------------------------------------===//
20 // Printing.
21 //===----------------------------------------------------------------------===//
22 
23 /// If the given type is compatible with the LLVM dialect, prints it using
24 /// internal functions to avoid getting a verbose `!llvm` prefix. Otherwise
25 /// prints it as usual.
26 static void dispatchPrint(AsmPrinter &printer, Type type) {
27  if (isCompatibleType(type) && !type.isa<IntegerType, FloatType, VectorType>())
28  return mlir::LLVM::detail::printType(type, printer);
29  printer.printType(type);
30 }
31 
32 /// Returns the keyword to use for the given type.
33 static StringRef getTypeKeyword(Type type) {
34  return TypeSwitch<Type, StringRef>(type)
35  .Case<LLVMVoidType>([&](Type) { return "void"; })
36  .Case<LLVMPPCFP128Type>([&](Type) { return "ppc_fp128"; })
37  .Case<LLVMX86MMXType>([&](Type) { return "x86_mmx"; })
38  .Case<LLVMTokenType>([&](Type) { return "token"; })
39  .Case<LLVMLabelType>([&](Type) { return "label"; })
40  .Case<LLVMMetadataType>([&](Type) { return "metadata"; })
41  .Case<LLVMFunctionType>([&](Type) { return "func"; })
42  .Case<LLVMPointerType>([&](Type) { return "ptr"; })
43  .Case<LLVMFixedVectorType, LLVMScalableVectorType>(
44  [&](Type) { return "vec"; })
45  .Case<LLVMArrayType>([&](Type) { return "array"; })
46  .Case<LLVMStructType>([&](Type) { return "struct"; })
47  .Default([](Type) -> StringRef {
48  llvm_unreachable("unexpected 'llvm' type kind");
49  });
50 }
51 
52 /// Prints a structure type. Keeps track of known struct names to handle self-
53 /// or mutually-referring structs without falling into infinite recursion.
54 static void printStructType(AsmPrinter &printer, LLVMStructType type) {
55  // This keeps track of the names of identified structure types that are
56  // currently being printed. Since such types can refer themselves, this
57  // tracking is necessary to stop the recursion: the current function may be
58  // called recursively from AsmPrinter::printType after the appropriate
59  // dispatch. We maintain the invariant of this storage being modified
60  // exclusively in this function, and at most one name being added per call.
61  // TODO: consider having such functionality inside AsmPrinter.
62  thread_local SetVector<StringRef> knownStructNames;
63  unsigned stackSize = knownStructNames.size();
64  (void)stackSize;
65  auto guard = llvm::make_scope_exit([&]() {
66  assert(knownStructNames.size() == stackSize &&
67  "malformed identified stack when printing recursive structs");
68  });
69 
70  printer << "<";
71  if (type.isIdentified()) {
72  printer << '"' << type.getName() << '"';
73  // If we are printing a reference to one of the enclosing structs, just
74  // print the name and stop to avoid infinitely long output.
75  if (knownStructNames.count(type.getName())) {
76  printer << '>';
77  return;
78  }
79  printer << ", ";
80  }
81 
82  if (type.isIdentified() && type.isOpaque()) {
83  printer << "opaque>";
84  return;
85  }
86 
87  if (type.isPacked())
88  printer << "packed ";
89 
90  // Put the current type on stack to avoid infinite recursion.
91  printer << '(';
92  if (type.isIdentified())
93  knownStructNames.insert(type.getName());
94  llvm::interleaveComma(type.getBody(), printer.getStream(),
95  [&](Type subtype) { dispatchPrint(printer, subtype); });
96  if (type.isIdentified())
97  knownStructNames.pop_back();
98  printer << ')';
99  printer << '>';
100 }
101 
102 /// Prints a type containing a fixed number of elements.
103 template <typename TypeTy>
104 static void printArrayOrVectorType(AsmPrinter &printer, TypeTy type) {
105  printer << '<' << type.getNumElements() << " x ";
106  dispatchPrint(printer, type.getElementType());
107  printer << '>';
108 }
109 
110 /// Prints a function type.
111 static void printFunctionType(AsmPrinter &printer, LLVMFunctionType funcType) {
112  printer << '<';
113  dispatchPrint(printer, funcType.getReturnType());
114  printer << " (";
115  llvm::interleaveComma(
116  funcType.getParams(), printer.getStream(),
117  [&printer](Type subtype) { dispatchPrint(printer, subtype); });
118  if (funcType.isVarArg()) {
119  if (funcType.getNumParams() != 0)
120  printer << ", ";
121  printer << "...";
122  }
123  printer << ")>";
124 }
125 
126 /// Prints the given LLVM dialect type recursively. This leverages closedness of
127 /// the LLVM dialect type system to avoid printing the dialect prefix
128 /// repeatedly. For recursive structures, only prints the name of the structure
129 /// when printing a self-reference. Note that this does not apply to sibling
130 /// references. For example,
131 /// struct<"a", (ptr<struct<"a">>)>
132 /// struct<"c", (ptr<struct<"b", (ptr<struct<"c">>)>>,
133 /// ptr<struct<"b", (ptr<struct<"c">>)>>)>
134 /// note that "b" is printed twice.
135 void mlir::LLVM::detail::printType(Type type, AsmPrinter &printer) {
136  if (!type) {
137  printer << "<<NULL-TYPE>>";
138  return;
139  }
140 
141  printer << getTypeKeyword(type);
142 
143  if (auto ptrType = type.dyn_cast<LLVMPointerType>()) {
144  printer << '<';
145  dispatchPrint(printer, ptrType.getElementType());
146  if (ptrType.getAddressSpace() != 0)
147  printer << ", " << ptrType.getAddressSpace();
148  printer << '>';
149  return;
150  }
151 
152  if (auto arrayType = type.dyn_cast<LLVMArrayType>())
153  return printArrayOrVectorType(printer, arrayType);
154  if (auto vectorType = type.dyn_cast<LLVMFixedVectorType>())
155  return printArrayOrVectorType(printer, vectorType);
156 
157  if (auto vectorType = type.dyn_cast<LLVMScalableVectorType>()) {
158  printer << "<? x " << vectorType.getMinNumElements() << " x ";
159  dispatchPrint(printer, vectorType.getElementType());
160  printer << '>';
161  return;
162  }
163 
164  if (auto structType = type.dyn_cast<LLVMStructType>())
165  return printStructType(printer, structType);
166 
167  if (auto funcType = type.dyn_cast<LLVMFunctionType>())
168  return printFunctionType(printer, funcType);
169 }
170 
171 //===----------------------------------------------------------------------===//
172 // Parsing.
173 //===----------------------------------------------------------------------===//
174 
175 static ParseResult dispatchParse(AsmParser &parser, Type &type);
176 
177 /// Parses an LLVM dialect function type.
178 /// llvm-type :: = `func<` llvm-type `(` llvm-type-list `...`? `)>`
179 static LLVMFunctionType parseFunctionType(AsmParser &parser) {
180  llvm::SMLoc loc = parser.getCurrentLocation();
181  Type returnType;
182  if (parser.parseLess() || dispatchParse(parser, returnType) ||
183  parser.parseLParen())
184  return LLVMFunctionType();
185 
186  // Function type without arguments.
187  if (succeeded(parser.parseOptionalRParen())) {
188  if (succeeded(parser.parseGreater()))
189  return parser.getChecked<LLVMFunctionType>(loc, returnType, llvm::None,
190  /*isVarArg=*/false);
191  return LLVMFunctionType();
192  }
193 
194  // Parse arguments.
195  SmallVector<Type, 8> argTypes;
196  do {
197  if (succeeded(parser.parseOptionalEllipsis())) {
198  if (parser.parseOptionalRParen() || parser.parseOptionalGreater())
199  return LLVMFunctionType();
200  return parser.getChecked<LLVMFunctionType>(loc, returnType, argTypes,
201  /*isVarArg=*/true);
202  }
203 
204  Type arg;
205  if (dispatchParse(parser, arg))
206  return LLVMFunctionType();
207  argTypes.push_back(arg);
208  } while (succeeded(parser.parseOptionalComma()));
209 
210  if (parser.parseOptionalRParen() || parser.parseOptionalGreater())
211  return LLVMFunctionType();
212  return parser.getChecked<LLVMFunctionType>(loc, returnType, argTypes,
213  /*isVarArg=*/false);
214 }
215 
216 /// Parses an LLVM dialect pointer type.
217 /// llvm-type ::= `ptr<` llvm-type (`,` integer)? `>`
218 static LLVMPointerType parsePointerType(AsmParser &parser) {
219  llvm::SMLoc loc = parser.getCurrentLocation();
220  Type elementType;
221  if (parser.parseLess() || dispatchParse(parser, elementType))
222  return LLVMPointerType();
223 
224  unsigned addressSpace = 0;
225  if (succeeded(parser.parseOptionalComma()) &&
226  failed(parser.parseInteger(addressSpace)))
227  return LLVMPointerType();
228  if (failed(parser.parseGreater()))
229  return LLVMPointerType();
230  return parser.getChecked<LLVMPointerType>(loc, elementType, addressSpace);
231 }
232 
233 /// Parses an LLVM dialect vector type.
234 /// llvm-type ::= `vec<` `? x`? integer `x` llvm-type `>`
235 /// Supports both fixed and scalable vectors.
236 static Type parseVectorType(AsmParser &parser) {
237  SmallVector<int64_t, 2> dims;
238  llvm::SMLoc dimPos, typePos;
239  Type elementType;
240  llvm::SMLoc loc = parser.getCurrentLocation();
241  if (parser.parseLess() || parser.getCurrentLocation(&dimPos) ||
242  parser.parseDimensionList(dims, /*allowDynamic=*/true) ||
243  parser.getCurrentLocation(&typePos) ||
244  dispatchParse(parser, elementType) || parser.parseGreater())
245  return Type();
246 
247  // We parsed a generic dimension list, but vectors only support two forms:
248  // - single non-dynamic entry in the list (fixed vector);
249  // - two elements, the first dynamic (indicated by -1) and the second
250  // non-dynamic (scalable vector).
251  if (dims.empty() || dims.size() > 2 ||
252  ((dims.size() == 2) ^ (dims[0] == -1)) ||
253  (dims.size() == 2 && dims[1] == -1)) {
254  parser.emitError(dimPos)
255  << "expected '? x <integer> x <type>' or '<integer> x <type>'";
256  return Type();
257  }
258 
259  bool isScalable = dims.size() == 2;
260  if (isScalable)
261  return parser.getChecked<LLVMScalableVectorType>(loc, elementType, dims[1]);
262  if (elementType.isSignlessIntOrFloat()) {
263  parser.emitError(typePos)
264  << "cannot use !llvm.vec for built-in primitives, use 'vector' instead";
265  return Type();
266  }
267  return parser.getChecked<LLVMFixedVectorType>(loc, elementType, dims[0]);
268 }
269 
270 /// Parses an LLVM dialect array type.
271 /// llvm-type ::= `array<` integer `x` llvm-type `>`
272 static LLVMArrayType parseArrayType(AsmParser &parser) {
273  SmallVector<int64_t, 1> dims;
274  llvm::SMLoc sizePos;
275  Type elementType;
276  llvm::SMLoc loc = parser.getCurrentLocation();
277  if (parser.parseLess() || parser.getCurrentLocation(&sizePos) ||
278  parser.parseDimensionList(dims, /*allowDynamic=*/false) ||
279  dispatchParse(parser, elementType) || parser.parseGreater())
280  return LLVMArrayType();
281 
282  if (dims.size() != 1) {
283  parser.emitError(sizePos) << "expected ? x <type>";
284  return LLVMArrayType();
285  }
286 
287  return parser.getChecked<LLVMArrayType>(loc, elementType, dims[0]);
288 }
289 
290 /// Attempts to set the body of an identified structure type. Reports a parsing
291 /// error at `subtypesLoc` in case of failure.
292 static LLVMStructType trySetStructBody(LLVMStructType type,
293  ArrayRef<Type> subtypes, bool isPacked,
294  AsmParser &parser,
295  llvm::SMLoc subtypesLoc) {
296  for (Type t : subtypes) {
297  if (!LLVMStructType::isValidElementType(t)) {
298  parser.emitError(subtypesLoc)
299  << "invalid LLVM structure element type: " << t;
300  return LLVMStructType();
301  }
302  }
303 
304  if (succeeded(type.setBody(subtypes, isPacked)))
305  return type;
306 
307  parser.emitError(subtypesLoc)
308  << "identified type already used with a different body";
309  return LLVMStructType();
310 }
311 
312 /// Parses an LLVM dialect structure type.
313 /// llvm-type ::= `struct<` (string-literal `,`)? `packed`?
314 /// `(` llvm-type-list `)` `>`
315 /// | `struct<` string-literal `>`
316 /// | `struct<` string-literal `, opaque>`
317 static LLVMStructType parseStructType(AsmParser &parser) {
318  // This keeps track of the names of identified structure types that are
319  // currently being parsed. Since such types can refer themselves, this
320  // tracking is necessary to stop the recursion: the current function may be
321  // called recursively from AsmParser::parseType after the appropriate
322  // dispatch. We maintain the invariant of this storage being modified
323  // exclusively in this function, and at most one name being added per call.
324  // TODO: consider having such functionality inside AsmParser.
325  thread_local SetVector<StringRef> knownStructNames;
326  unsigned stackSize = knownStructNames.size();
327  (void)stackSize;
328  auto guard = llvm::make_scope_exit([&]() {
329  assert(knownStructNames.size() == stackSize &&
330  "malformed identified stack when parsing recursive structs");
331  });
332 
333  Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
334 
335  if (failed(parser.parseLess()))
336  return LLVMStructType();
337 
338  // If we are parsing a self-reference to a recursive struct, i.e. the parsing
339  // stack already contains a struct with the same identifier, bail out after
340  // the name.
341  std::string name;
342  bool isIdentified = succeeded(parser.parseOptionalString(&name));
343  if (isIdentified) {
344  if (knownStructNames.count(name)) {
345  if (failed(parser.parseGreater()))
346  return LLVMStructType();
347  return LLVMStructType::getIdentifiedChecked(
348  [loc] { return emitError(loc); }, loc.getContext(), name);
349  }
350  if (failed(parser.parseComma()))
351  return LLVMStructType();
352  }
353 
354  // Handle intentionally opaque structs.
355  llvm::SMLoc kwLoc = parser.getCurrentLocation();
356  if (succeeded(parser.parseOptionalKeyword("opaque"))) {
357  if (!isIdentified)
358  return parser.emitError(kwLoc, "only identified structs can be opaque"),
359  LLVMStructType();
360  if (failed(parser.parseGreater()))
361  return LLVMStructType();
362  auto type = LLVMStructType::getOpaqueChecked(
363  [loc] { return emitError(loc); }, loc.getContext(), name);
364  if (!type.isOpaque()) {
365  parser.emitError(kwLoc, "redeclaring defined struct as opaque");
366  return LLVMStructType();
367  }
368  return type;
369  }
370 
371  // Check for packedness.
372  bool isPacked = succeeded(parser.parseOptionalKeyword("packed"));
373  if (failed(parser.parseLParen()))
374  return LLVMStructType();
375 
376  // Fast pass for structs with zero subtypes.
377  if (succeeded(parser.parseOptionalRParen())) {
378  if (failed(parser.parseGreater()))
379  return LLVMStructType();
380  if (!isIdentified)
381  return LLVMStructType::getLiteralChecked([loc] { return emitError(loc); },
382  loc.getContext(), {}, isPacked);
383  auto type = LLVMStructType::getIdentifiedChecked(
384  [loc] { return emitError(loc); }, loc.getContext(), name);
385  return trySetStructBody(type, {}, isPacked, parser, kwLoc);
386  }
387 
388  // Parse subtypes. For identified structs, put the identifier of the struct on
389  // the stack to support self-references in the recursive calls.
390  SmallVector<Type, 4> subtypes;
391  llvm::SMLoc subtypesLoc = parser.getCurrentLocation();
392  do {
393  if (isIdentified)
394  knownStructNames.insert(name);
395  Type type;
396  if (dispatchParse(parser, type))
397  return LLVMStructType();
398  subtypes.push_back(type);
399  if (isIdentified)
400  knownStructNames.pop_back();
401  } while (succeeded(parser.parseOptionalComma()));
402 
403  if (parser.parseRParen() || parser.parseGreater())
404  return LLVMStructType();
405 
406  // Construct the struct with body.
407  if (!isIdentified)
408  return LLVMStructType::getLiteralChecked(
409  [loc] { return emitError(loc); }, loc.getContext(), subtypes, isPacked);
410  auto type = LLVMStructType::getIdentifiedChecked(
411  [loc] { return emitError(loc); }, loc.getContext(), name);
412  return trySetStructBody(type, subtypes, isPacked, parser, subtypesLoc);
413 }
414 
415 /// Parses a type appearing inside another LLVM dialect-compatible type. This
416 /// will try to parse any type in full form (including types with the `!llvm`
417 /// prefix), and on failure fall back to parsing the short-hand version of the
418 /// LLVM dialect types without the `!llvm` prefix.
419 static Type dispatchParse(AsmParser &parser, bool allowAny = true) {
420  llvm::SMLoc keyLoc = parser.getCurrentLocation();
421 
422  // Try parsing any MLIR type.
423  Type type;
424  OptionalParseResult result = parser.parseOptionalType(type);
425  if (result.hasValue()) {
426  if (failed(result.getValue()))
427  return nullptr;
428  if (!allowAny) {
429  parser.emitError(keyLoc) << "unexpected type, expected keyword";
430  return nullptr;
431  }
432  return type;
433  }
434 
435  // If no type found, fallback to the shorthand form.
436  StringRef key;
437  if (failed(parser.parseKeyword(&key)))
438  return Type();
439 
440  MLIRContext *ctx = parser.getContext();
441  return StringSwitch<function_ref<Type()>>(key)
442  .Case("void", [&] { return LLVMVoidType::get(ctx); })
443  .Case("ppc_fp128", [&] { return LLVMPPCFP128Type::get(ctx); })
444  .Case("x86_mmx", [&] { return LLVMX86MMXType::get(ctx); })
445  .Case("token", [&] { return LLVMTokenType::get(ctx); })
446  .Case("label", [&] { return LLVMLabelType::get(ctx); })
447  .Case("metadata", [&] { return LLVMMetadataType::get(ctx); })
448  .Case("func", [&] { return parseFunctionType(parser); })
449  .Case("ptr", [&] { return parsePointerType(parser); })
450  .Case("vec", [&] { return parseVectorType(parser); })
451  .Case("array", [&] { return parseArrayType(parser); })
452  .Case("struct", [&] { return parseStructType(parser); })
453  .Default([&] {
454  parser.emitError(keyLoc) << "unknown LLVM type: " << key;
455  return Type();
456  })();
457 }
458 
459 /// Helper to use in parse lists.
460 static ParseResult dispatchParse(AsmParser &parser, Type &type) {
461  type = dispatchParse(parser);
462  return success(type != nullptr);
463 }
464 
465 /// Parses one of the LLVM dialect types.
466 Type mlir::LLVM::detail::parseType(DialectAsmParser &parser) {
467  llvm::SMLoc loc = parser.getCurrentLocation();
468  Type type = dispatchParse(parser, /*allowAny=*/false);
469  if (!type)
470  return type;
471  if (!isCompatibleOuterType(type)) {
472  parser.emitError(loc) << "unexpected type, expected keyword";
473  return nullptr;
474  }
475  return type;
476 }
Include the generated interface declarations.
StringRef getName()
Returns the name of an identified struct.
Definition: LLVMTypes.cpp:407
virtual void printType(Type type)
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:752
static void printStructType(AsmPrinter &printer, LLVMStructType type)
Prints a structure type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:72
static void dispatchPrint(AsmPrinter &printer, Type type)
If the given type is compatible with the LLVM dialect, prints it using internal functions to avoid getting a verbose `!llvm` prefix.
static StringRef getTypeKeyword(Type type)
Returns the keyword to use for the given type.
LLVM dialect structure type representing a collection of different-typed elements manipulated together.
Definition: LLVMTypes.h:252
bool isIdentified() const
Checks if a struct is identified.
Definition: LLVMTypes.cpp:401
This base class exposes generic asm printer hooks, usable across the various derived printers...
bool isa() const
Definition: Types.h:234