MLIR  22.0.0git
SPIRVDialect.cpp
Go to the documentation of this file.
1 //===- SPIRVDialect.cpp - MLIR SPIR-V dialect -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the SPIR-V dialect in MLIR.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
15 #include "SPIRVParsingUtils.h"
16 
22 #include "mlir/IR/Builders.h"
23 #include "mlir/IR/BuiltinTypes.h"
25 #include "mlir/IR/MLIRContext.h"
26 #include "mlir/Parser/Parser.h"
28 #include "llvm/ADT/Sequence.h"
29 #include "llvm/ADT/StringExtras.h"
30 #include "llvm/ADT/TypeSwitch.h"
31 
32 using namespace mlir;
33 using namespace mlir::spirv;
34 
35 #include "mlir/Dialect/SPIRV/IR/SPIRVOpsDialect.cpp.inc"
36 
37 //===----------------------------------------------------------------------===//
38 // InlinerInterface
39 //===----------------------------------------------------------------------===//
40 
41 /// Returns true if the given region contains spirv.Return or spirv.ReturnValue
42 /// ops.
43 static inline bool containsReturn(Region &region) {
44  return llvm::any_of(region, [](Block &block) {
45  Operation *terminator = block.getTerminator();
46  return isa<spirv::ReturnOp, spirv::ReturnValueOp>(terminator);
47  });
48 }
49 
50 namespace {
51 /// This class defines the interface for inlining within the SPIR-V dialect.
52 struct SPIRVInlinerInterface : public DialectInlinerInterface {
54 
55  /// All call operations within SPIRV can be inlined.
56  bool isLegalToInline(Operation *call, Operation *callable,
57  bool wouldBeCloned) const final {
58  return true;
59  }
60 
61  /// Returns true if the given region 'src' can be inlined into the region
62  /// 'dest' that is attached to an operation registered to the current dialect.
63  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
64  IRMapping &) const final {
65  // Return true here when inlining into spirv.func, spirv.mlir.selection, and
66  // spirv.mlir.loop operations.
67  auto *op = dest->getParentOp();
68  return isa<spirv::FuncOp, spirv::SelectionOp, spirv::LoopOp>(op);
69  }
70 
71  /// Returns true if the given operation 'op', that is registered to this
72  /// dialect, can be inlined into the region 'dest' that is attached to an
73  /// operation registered to the current dialect.
74  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
75  IRMapping &) const final {
76  // TODO: Enable inlining structured control flows with return.
77  if ((isa<spirv::SelectionOp, spirv::LoopOp>(op)) &&
78  containsReturn(op->getRegion(0)))
79  return false;
80  // TODO: we need to filter OpKill here to avoid inlining it to
81  // a loop continue construct:
82  // https://github.com/KhronosGroup/SPIRV-Headers/issues/86
83  // For now, we just disallow inlining OpKill anywhere in the code,
84  // but this restriction should be relaxed, as pointed above.
85  if (isa<spirv::KillOp>(op))
86  return false;
87 
88  return true;
89  }
90 
91  /// Handle the given inlined terminator by replacing it with a new operation
92  /// as necessary.
93  void handleTerminator(Operation *op, Block *newDest) const final {
94  if (auto returnOp = dyn_cast<spirv::ReturnOp>(op)) {
95  auto builder = OpBuilder(op);
96  spirv::BranchOp::create(builder, op->getLoc(), newDest);
97  op->erase();
98  } else if (auto retValOp = dyn_cast<spirv::ReturnValueOp>(op)) {
99  auto builder = OpBuilder(op);
100  spirv::BranchOp::create(builder, retValOp->getLoc(), newDest,
101  retValOp->getOperands());
102  op->erase();
103  }
104  }
105 
106  /// Handle the given inlined terminator by replacing it with a new operation
107  /// as necessary.
108  void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
109  // Only spirv.ReturnValue needs to be handled here.
110  auto retValOp = dyn_cast<spirv::ReturnValueOp>(op);
111  if (!retValOp)
112  return;
113 
114  // Replace the values directly with the return operands.
115  assert(valuesToRepl.size() == 1 &&
116  "spirv.ReturnValue expected to only handle one result");
117  valuesToRepl.front().replaceAllUsesWith(retValOp.getValue());
118  }
119 };
120 } // namespace
121 
122 //===----------------------------------------------------------------------===//
123 // SPIR-V Dialect
124 //===----------------------------------------------------------------------===//
125 
126 void SPIRVDialect::initialize() {
127  registerAttributes();
128  registerTypes();
129 
130  // Add SPIR-V ops.
131  addOperations<
132 #define GET_OP_LIST
133 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"
134  >();
135 
136  addInterfaces<SPIRVInlinerInterface>();
137 
138  // Allow unknown operations because SPIR-V is extensible.
139  allowUnknownOperations();
140  declarePromisedInterface<gpu::TargetAttrInterface, TargetEnvAttr>();
141 }
142 
143 std::string SPIRVDialect::getAttributeName(Decoration decoration) {
144  return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration));
145 }
146 
147 //===----------------------------------------------------------------------===//
148 // Type Parsing
149 //===----------------------------------------------------------------------===//
150 
151 // Forward declarations.
152 template <typename ValTy>
153 static std::optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
154  DialectAsmParser &parser);
155 template <>
156 std::optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
157  DialectAsmParser &parser);
158 
159 template <>
160 std::optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
161  DialectAsmParser &parser);
162 
163 static Type parseAndVerifyType(SPIRVDialect const &dialect,
164  DialectAsmParser &parser) {
165  Type type;
166  SMLoc typeLoc = parser.getCurrentLocation();
167  if (parser.parseType(type))
168  return Type();
169 
170  // Allow SPIR-V dialect types
171  if (&type.getDialect() == &dialect)
172  return type;
173 
174  // Check other allowed types
175  if (auto t = llvm::dyn_cast<FloatType>(type)) {
176  // TODO: All float types are allowed for now, but this should be fixed.
177  } else if (auto t = llvm::dyn_cast<IntegerType>(type)) {
178  if (!ScalarType::isValid(t)) {
179  parser.emitError(typeLoc,
180  "only 1/8/16/32/64-bit integer type allowed but found ")
181  << type;
182  return Type();
183  }
184  } else if (auto t = llvm::dyn_cast<VectorType>(type)) {
185  if (t.getRank() != 1) {
186  parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
187  return Type();
188  }
189  if (t.getNumElements() > 4) {
190  parser.emitError(
191  typeLoc, "vector length has to be less than or equal to 4 but found ")
192  << t.getNumElements();
193  return Type();
194  }
195  } else if (auto t = dyn_cast<TensorArmType>(type)) {
196  if (!isa<ScalarType>(t.getElementType())) {
197  parser.emitError(
198  typeLoc, "only scalar element type allowed in tensor type but found ")
199  << t.getElementType();
200  return Type();
201  }
202  } else {
203  parser.emitError(typeLoc, "cannot use ")
204  << type << " to compose SPIR-V types";
205  return Type();
206  }
207 
208  return type;
209 }
210 
211 static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect,
212  DialectAsmParser &parser) {
213  Type type;
214  SMLoc typeLoc = parser.getCurrentLocation();
215  if (parser.parseType(type))
216  return Type();
217 
218  if (auto t = llvm::dyn_cast<VectorType>(type)) {
219  if (t.getRank() != 1) {
220  parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
221  return Type();
222  }
223  if (t.getNumElements() > 4 || t.getNumElements() < 2) {
224  parser.emitError(typeLoc,
225  "matrix columns size has to be less than or equal "
226  "to 4 and greater than or equal 2, but found ")
227  << t.getNumElements();
228  return Type();
229  }
230 
231  if (!llvm::isa<FloatType>(t.getElementType())) {
232  parser.emitError(typeLoc, "matrix columns' elements must be of "
233  "Float type, got ")
234  << t.getElementType();
235  return Type();
236  }
237  } else {
238  parser.emitError(typeLoc, "matrix must be composed using vector "
239  "type, got ")
240  << type;
241  return Type();
242  }
243 
244  return type;
245 }
246 
247 static Type parseAndVerifySampledImageType(SPIRVDialect const &dialect,
248  DialectAsmParser &parser) {
249  Type type;
250  SMLoc typeLoc = parser.getCurrentLocation();
251  if (parser.parseType(type))
252  return Type();
253 
254  auto imageType = dyn_cast<ImageType>(type);
255  if (!imageType) {
256  parser.emitError(typeLoc,
257  "sampled image must be composed using image type, got ")
258  << type;
259  return Type();
260  }
261 
262  if (llvm::is_contained({Dim::SubpassData, Dim::Buffer}, imageType.getDim())) {
263  parser.emitError(
264  typeLoc, "sampled image Dim must not be SubpassData or Buffer, got ")
265  << stringifyDim(imageType.getDim());
266  return Type();
267  }
268 
269  return type;
270 }
271 
272 /// Parses an optional `, stride = N` assembly segment. If no parsing failure
273 /// occurs, writes `N` to `stride` if existing and writes 0 to `stride` if
274 /// missing.
275 static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect,
276  DialectAsmParser &parser,
277  unsigned &stride) {
278  if (failed(parser.parseOptionalComma())) {
279  stride = 0;
280  return success();
281  }
282 
283  if (parser.parseKeyword("stride") || parser.parseEqual())
284  return failure();
285 
286  SMLoc strideLoc = parser.getCurrentLocation();
287  std::optional<unsigned> optStride = parseAndVerify<unsigned>(dialect, parser);
288  if (!optStride)
289  return failure();
290 
291  if (!(stride = *optStride)) {
292  parser.emitError(strideLoc, "ArrayStride must be greater than zero");
293  return failure();
294  }
295  return success();
296 }
297 
298 // element-type ::= integer-type
299 // | floating-point-type
300 // | vector-type
301 // | spirv-type
302 //
303 // array-type ::= `!spirv.array` `<` integer-literal `x` element-type
304 // (`,` `stride` `=` integer-literal)? `>`
305 static Type parseArrayType(SPIRVDialect const &dialect,
306  DialectAsmParser &parser) {
307  if (parser.parseLess())
308  return Type();
309 
310  SmallVector<int64_t, 1> countDims;
311  SMLoc countLoc = parser.getCurrentLocation();
312  if (parser.parseDimensionList(countDims, /*allowDynamic=*/false))
313  return Type();
314  if (countDims.size() != 1) {
315  parser.emitError(countLoc,
316  "expected single integer for array element count");
317  return Type();
318  }
319 
320  // According to the SPIR-V spec:
321  // "Length is the number of elements in the array. It must be at least 1."
322  int64_t count = countDims[0];
323  if (count == 0) {
324  parser.emitError(countLoc, "expected array length greater than 0");
325  return Type();
326  }
327 
328  Type elementType = parseAndVerifyType(dialect, parser);
329  if (!elementType)
330  return Type();
331 
332  unsigned stride = 0;
333  if (failed(parseOptionalArrayStride(dialect, parser, stride)))
334  return Type();
335 
336  if (parser.parseGreater())
337  return Type();
338  return ArrayType::get(elementType, count, stride);
339 }
340 
341 // cooperative-matrix-type ::=
342 // `!spirv.coopmatrix` `<` rows `x` columns `x` element-type `,`
343 // scope `,` use `>`
344 static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
345  DialectAsmParser &parser) {
346  if (parser.parseLess())
347  return {};
348 
350  SMLoc countLoc = parser.getCurrentLocation();
351  if (parser.parseDimensionList(dims, /*allowDynamic=*/false))
352  return {};
353 
354  if (dims.size() != 2) {
355  parser.emitError(countLoc, "expected row and column count");
356  return {};
357  }
358 
359  auto elementTy = parseAndVerifyType(dialect, parser);
360  if (!elementTy)
361  return {};
362 
363  Scope scope;
364  if (parser.parseComma() ||
365  spirv::parseEnumKeywordAttr(scope, parser, "scope <id>"))
366  return {};
367 
368  CooperativeMatrixUseKHR use;
369  if (parser.parseComma() ||
370  spirv::parseEnumKeywordAttr(use, parser, "use <id>"))
371  return {};
372 
373  if (parser.parseGreater())
374  return {};
375 
376  return CooperativeMatrixType::get(elementTy, dims[0], dims[1], scope, use);
377 }
378 
379 // tensor-arm-type ::=
380 // `!spirv.arm.tensor` `<` dim0 `x` dim1 `x` ... `x` dimN `x` element-type`>`
381 static Type parseTensorArmType(SPIRVDialect const &dialect,
382  DialectAsmParser &parser) {
383  if (parser.parseLess())
384  return {};
385 
386  bool unranked = false;
388  SMLoc countLoc = parser.getCurrentLocation();
389 
390  if (parser.parseOptionalStar().succeeded()) {
391  unranked = true;
392  if (parser.parseXInDimensionList())
393  return {};
394  } else if (parser.parseDimensionList(dims, /*allowDynamic=*/true)) {
395  return {};
396  }
397 
398  if (!unranked && dims.empty()) {
399  parser.emitError(countLoc, "arm.tensors do not support rank zero");
400  return {};
401  }
402 
403  if (llvm::is_contained(dims, 0)) {
404  parser.emitError(countLoc, "arm.tensors do not support zero dimensions");
405  return {};
406  }
407 
408  if (llvm::any_of(dims, [](int64_t dim) { return dim < 0; }) &&
409  llvm::any_of(dims, [](int64_t dim) { return dim > 0; })) {
410  parser.emitError(countLoc, "arm.tensor shape dimensions must be either "
411  "fully dynamic or completed shaped");
412  return {};
413  }
414 
415  auto elementTy = parseAndVerifyType(dialect, parser);
416  if (!elementTy)
417  return {};
418 
419  if (parser.parseGreater())
420  return {};
421 
422  return TensorArmType::get(dims, elementTy);
423 }
424 
425 // TODO: Reorder methods to be utilities first and parse*Type
426 // methods in alphabetical order
427 //
428 // storage-class ::= `UniformConstant`
429 // | `Uniform`
430 // | `Workgroup`
431 // | <and other storage classes...>
432 //
433 // pointer-type ::= `!spirv.ptr<` element-type `,` storage-class `>`
434 static Type parsePointerType(SPIRVDialect const &dialect,
435  DialectAsmParser &parser) {
436  if (parser.parseLess())
437  return Type();
438 
439  auto pointeeType = parseAndVerifyType(dialect, parser);
440  if (!pointeeType)
441  return Type();
442 
443  StringRef storageClassSpec;
444  SMLoc storageClassLoc = parser.getCurrentLocation();
445  if (parser.parseComma() || parser.parseKeyword(&storageClassSpec))
446  return Type();
447 
448  auto storageClass = symbolizeStorageClass(storageClassSpec);
449  if (!storageClass) {
450  parser.emitError(storageClassLoc, "unknown storage class: ")
451  << storageClassSpec;
452  return Type();
453  }
454  if (parser.parseGreater())
455  return Type();
456  return PointerType::get(pointeeType, *storageClass);
457 }
458 
459 // runtime-array-type ::= `!spirv.rtarray` `<` element-type
460 // (`,` `stride` `=` integer-literal)? `>`
461 static Type parseRuntimeArrayType(SPIRVDialect const &dialect,
462  DialectAsmParser &parser) {
463  if (parser.parseLess())
464  return Type();
465 
466  Type elementType = parseAndVerifyType(dialect, parser);
467  if (!elementType)
468  return Type();
469 
470  unsigned stride = 0;
471  if (failed(parseOptionalArrayStride(dialect, parser, stride)))
472  return Type();
473 
474  if (parser.parseGreater())
475  return Type();
476  return RuntimeArrayType::get(elementType, stride);
477 }
478 
479 // matrix-type ::= `!spirv.matrix` `<` integer-literal `x` element-type `>`
480 static Type parseMatrixType(SPIRVDialect const &dialect,
481  DialectAsmParser &parser) {
482  if (parser.parseLess())
483  return Type();
484 
485  SmallVector<int64_t, 1> countDims;
486  SMLoc countLoc = parser.getCurrentLocation();
487  if (parser.parseDimensionList(countDims, /*allowDynamic=*/false))
488  return Type();
489  if (countDims.size() != 1) {
490  parser.emitError(countLoc, "expected single unsigned "
491  "integer for number of columns");
492  return Type();
493  }
494 
495  int64_t columnCount = countDims[0];
496  // According to the specification, Matrices can have 2, 3, or 4 columns
497  if (columnCount < 2 || columnCount > 4) {
498  parser.emitError(countLoc, "matrix is expected to have 2, 3, or 4 "
499  "columns");
500  return Type();
501  }
502 
503  Type columnType = parseAndVerifyMatrixType(dialect, parser);
504  if (!columnType)
505  return Type();
506 
507  if (parser.parseGreater())
508  return Type();
509 
510  return MatrixType::get(columnType, columnCount);
511 }
512 
513 // Specialize this function to parse each of the parameters that define an
514 // ImageType. By default it assumes this is an enum type.
515 template <typename ValTy>
516 static std::optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
517  DialectAsmParser &parser) {
518  StringRef enumSpec;
519  SMLoc enumLoc = parser.getCurrentLocation();
520  if (parser.parseKeyword(&enumSpec)) {
521  return std::nullopt;
522  }
523 
524  auto val = spirv::symbolizeEnum<ValTy>(enumSpec);
525  if (!val)
526  parser.emitError(enumLoc, "unknown attribute: '") << enumSpec << "'";
527  return val;
528 }
529 
530 template <>
531 std::optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
532  DialectAsmParser &parser) {
533  // TODO: Further verify that the element type can be sampled
534  auto ty = parseAndVerifyType(dialect, parser);
535  if (!ty)
536  return std::nullopt;
537  return ty;
538 }
539 
540 template <typename IntTy>
541 static std::optional<IntTy> parseAndVerifyInteger(SPIRVDialect const &dialect,
542  DialectAsmParser &parser) {
543  IntTy offsetVal = std::numeric_limits<IntTy>::max();
544  if (parser.parseInteger(offsetVal))
545  return std::nullopt;
546  return offsetVal;
547 }
548 
549 template <>
550 std::optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
551  DialectAsmParser &parser) {
552  return parseAndVerifyInteger<unsigned>(dialect, parser);
553 }
554 
555 namespace {
556 // Functor object to parse a comma separated list of specs. The function
557 // parseAndVerify does the actual parsing and verification of individual
558 // elements. This is a functor since parsing the last element of the list
559 // (termination condition) needs partial specialization.
560 template <typename ParseType, typename... Args>
561 struct ParseCommaSeparatedList {
562  std::optional<std::tuple<ParseType, Args...>>
563  operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const {
564  auto parseVal = parseAndVerify<ParseType>(dialect, parser);
565  if (!parseVal)
566  return std::nullopt;
567 
568  auto numArgs = std::tuple_size<std::tuple<Args...>>::value;
569  if (numArgs != 0 && failed(parser.parseComma()))
570  return std::nullopt;
571  auto remainingValues = ParseCommaSeparatedList<Args...>{}(dialect, parser);
572  if (!remainingValues)
573  return std::nullopt;
574  return std::tuple_cat(std::tuple<ParseType>(parseVal.value()),
575  remainingValues.value());
576  }
577 };
578 
579 // Partial specialization of the function to parse a comma separated list of
580 // specs to parse the last element of the list.
581 template <typename ParseType>
582 struct ParseCommaSeparatedList<ParseType> {
583  std::optional<std::tuple<ParseType>>
584  operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const {
585  if (auto value = parseAndVerify<ParseType>(dialect, parser))
586  return std::tuple<ParseType>(*value);
587  return std::nullopt;
588  }
589 };
590 } // namespace
591 
592 // dim ::= `1D` | `2D` | `3D` | `Cube` | <and other SPIR-V Dim specifiers...>
593 //
594 // depth-info ::= `NoDepth` | `IsDepth` | `DepthUnknown`
595 //
596 // arrayed-info ::= `NonArrayed` | `Arrayed`
597 //
598 // sampling-info ::= `SingleSampled` | `MultiSampled`
599 //
600 // sampler-use-info ::= `SamplerUnknown` | `NeedSampler` | `NoSampler`
601 //
602 // format ::= `Unknown` | `Rgba32f` | <and other SPIR-V Image formats...>
603 //
604 // image-type ::= `!spirv.image<` element-type `,` dim `,` depth-info `,`
605 // arrayed-info `,` sampling-info `,`
606 // sampler-use-info `,` format `>`
607 static Type parseImageType(SPIRVDialect const &dialect,
608  DialectAsmParser &parser) {
609  if (parser.parseLess())
610  return Type();
611 
612  auto value =
613  ParseCommaSeparatedList<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
614  ImageSamplingInfo, ImageSamplerUseInfo,
615  ImageFormat>{}(dialect, parser);
616  if (!value)
617  return Type();
618 
619  if (parser.parseGreater())
620  return Type();
621  return ImageType::get(*value);
622 }
623 
624 // sampledImage-type :: = `!spirv.sampledImage<` image-type `>`
625 static Type parseSampledImageType(SPIRVDialect const &dialect,
626  DialectAsmParser &parser) {
627  if (parser.parseLess())
628  return Type();
629 
630  Type parsedType = parseAndVerifySampledImageType(dialect, parser);
631  if (!parsedType)
632  return Type();
633 
634  if (parser.parseGreater())
635  return Type();
636  return SampledImageType::get(parsedType);
637 }
638 
639 // Parse decorations associated with a member.
640 static ParseResult parseStructMemberDecorations(
641  SPIRVDialect const &dialect, DialectAsmParser &parser,
642  ArrayRef<Type> memberTypes,
645 
646  // Check if the first element is offset.
647  SMLoc offsetLoc = parser.getCurrentLocation();
648  StructType::OffsetInfo offset = 0;
649  OptionalParseResult offsetParseResult = parser.parseOptionalInteger(offset);
650  if (offsetParseResult.has_value()) {
651  if (failed(*offsetParseResult))
652  return failure();
653 
654  if (offsetInfo.size() != memberTypes.size() - 1) {
655  return parser.emitError(offsetLoc,
656  "offset specification must be given for "
657  "all members");
658  }
659  offsetInfo.push_back(offset);
660  }
661 
662  // Check for no spirv::Decorations.
663  if (succeeded(parser.parseOptionalRSquare()))
664  return success();
665 
666  // If there was an offset, make sure to parse the comma.
667  if (offsetParseResult.has_value() && parser.parseComma())
668  return failure();
669 
670  // Check for spirv::Decorations.
671  auto parseDecorations = [&]() {
672  auto memberDecoration = parseAndVerify<spirv::Decoration>(dialect, parser);
673  if (!memberDecoration)
674  return failure();
675 
676  // Parse member decoration value if it exists.
677  if (succeeded(parser.parseOptionalEqual())) {
678  Attribute memberDecorationValue;
679  if (failed(parser.parseAttribute(memberDecorationValue)))
680  return failure();
681 
682  memberDecorationInfo.emplace_back(
683  static_cast<uint32_t>(memberTypes.size() - 1),
684  memberDecoration.value(), memberDecorationValue);
685  } else {
686  memberDecorationInfo.emplace_back(
687  static_cast<uint32_t>(memberTypes.size() - 1),
688  memberDecoration.value(), UnitAttr::get(dialect.getContext()));
689  }
690  return success();
691  };
692  if (failed(parser.parseCommaSeparatedList(parseDecorations)) ||
693  failed(parser.parseRSquare()))
694  return failure();
695 
696  return success();
697 }
698 
699 // struct-member-decoration ::= integer-literal? spirv-decoration*
700 // struct-type ::=
701 // `!spirv.struct<` (id `,`)?
702 // `(`
703 // (spirv-type (`[` struct-member-decoration `]`)?)*
704 // `)`
705 // (`,` struct-decoration)?
706 // `>`
707 static Type parseStructType(SPIRVDialect const &dialect,
708  DialectAsmParser &parser) {
709  // TODO: This function is quite lengthy. Break it down into smaller chunks.
710 
711  if (parser.parseLess())
712  return Type();
713 
714  StringRef identifier;
715  FailureOr<DialectAsmParser::CyclicParseReset> cyclicParse;
716 
717  // Check if this is an identified struct type.
718  if (succeeded(parser.parseOptionalKeyword(&identifier))) {
719  // Check if this is a possible recursive reference.
720  auto structType =
721  StructType::getIdentified(dialect.getContext(), identifier);
722  cyclicParse = parser.tryStartCyclicParse(structType);
723  if (succeeded(parser.parseOptionalGreater())) {
724  if (succeeded(cyclicParse)) {
725  parser.emitError(
726  parser.getNameLoc(),
727  "recursive struct reference not nested in struct definition");
728 
729  return Type();
730  }
731 
732  return structType;
733  }
734 
735  if (failed(parser.parseComma()))
736  return Type();
737 
738  if (failed(cyclicParse)) {
739  parser.emitError(parser.getNameLoc(),
740  "identifier already used for an enclosing struct");
741  return Type();
742  }
743  }
744 
745  if (failed(parser.parseLParen()))
746  return Type();
747 
748  if (succeeded(parser.parseOptionalRParen()) &&
749  succeeded(parser.parseOptionalGreater())) {
750  return StructType::getEmpty(dialect.getContext(), identifier);
751  }
752 
753  StructType idStructTy;
754 
755  if (!identifier.empty())
756  idStructTy = StructType::getIdentified(dialect.getContext(), identifier);
757 
758  SmallVector<Type, 4> memberTypes;
761 
762  do {
763  Type memberType;
764  if (parser.parseType(memberType))
765  return Type();
766  memberTypes.push_back(memberType);
767 
768  if (succeeded(parser.parseOptionalLSquare()))
769  if (parseStructMemberDecorations(dialect, parser, memberTypes, offsetInfo,
770  memberDecorationInfo))
771  return Type();
772  } while (succeeded(parser.parseOptionalComma()));
773 
774  if (!offsetInfo.empty() && memberTypes.size() != offsetInfo.size()) {
775  parser.emitError(parser.getNameLoc(),
776  "offset specification must be given for all members");
777  return Type();
778  }
779 
780  if (failed(parser.parseRParen()))
781  return Type();
782 
784 
785  auto parseStructDecoration = [&]() {
786  std::optional<spirv::Decoration> decoration =
787  parseAndVerify<spirv::Decoration>(dialect, parser);
788  if (!decoration)
789  return failure();
790 
791  // Parse decoration value if it exists.
792  if (succeeded(parser.parseOptionalEqual())) {
793  Attribute decorationValue;
794  if (failed(parser.parseAttribute(decorationValue)))
795  return failure();
796 
797  structDecorationInfo.emplace_back(decoration.value(), decorationValue);
798  } else {
799  structDecorationInfo.emplace_back(decoration.value(),
800  UnitAttr::get(dialect.getContext()));
801  }
802  return success();
803  };
804 
805  while (succeeded(parser.parseOptionalComma()))
806  if (failed(parseStructDecoration()))
807  return Type();
808 
809  if (failed(parser.parseGreater()))
810  return Type();
811 
812  if (!identifier.empty()) {
813  if (failed(idStructTy.trySetBody(memberTypes, offsetInfo,
814  memberDecorationInfo,
815  structDecorationInfo)))
816  return Type();
817  return idStructTy;
818  }
819 
820  return StructType::get(memberTypes, offsetInfo, memberDecorationInfo,
821  structDecorationInfo);
822 }
823 
824 // spirv-type ::= array-type
825 // | element-type
826 // | image-type
827 // | pointer-type
828 // | runtime-array-type
829 // | sampled-image-type
830 // | struct-type
832  StringRef keyword;
833  if (parser.parseKeyword(&keyword))
834  return Type();
835 
836  if (keyword == "array")
837  return parseArrayType(*this, parser);
838  if (keyword == "coopmatrix")
839  return parseCooperativeMatrixType(*this, parser);
840  if (keyword == "image")
841  return parseImageType(*this, parser);
842  if (keyword == "ptr")
843  return parsePointerType(*this, parser);
844  if (keyword == "rtarray")
845  return parseRuntimeArrayType(*this, parser);
846  if (keyword == "sampled_image")
847  return parseSampledImageType(*this, parser);
848  if (keyword == "struct")
849  return parseStructType(*this, parser);
850  if (keyword == "matrix")
851  return parseMatrixType(*this, parser);
852  if (keyword == "arm.tensor")
853  return parseTensorArmType(*this, parser);
854  parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword;
855  return Type();
856 }
857 
858 //===----------------------------------------------------------------------===//
859 // Type Printing
860 //===----------------------------------------------------------------------===//
861 
862 static void print(ArrayType type, DialectAsmPrinter &os) {
863  os << "array<" << type.getNumElements() << " x " << type.getElementType();
864  if (unsigned stride = type.getArrayStride())
865  os << ", stride=" << stride;
866  os << ">";
867 }
868 
869 static void print(RuntimeArrayType type, DialectAsmPrinter &os) {
870  os << "rtarray<" << type.getElementType();
871  if (unsigned stride = type.getArrayStride())
872  os << ", stride=" << stride;
873  os << ">";
874 }
875 
876 static void print(PointerType type, DialectAsmPrinter &os) {
877  os << "ptr<" << type.getPointeeType() << ", "
878  << stringifyStorageClass(type.getStorageClass()) << ">";
879 }
880 
881 static void print(ImageType type, DialectAsmPrinter &os) {
882  os << "image<" << type.getElementType() << ", " << stringifyDim(type.getDim())
883  << ", " << stringifyImageDepthInfo(type.getDepthInfo()) << ", "
884  << stringifyImageArrayedInfo(type.getArrayedInfo()) << ", "
885  << stringifyImageSamplingInfo(type.getSamplingInfo()) << ", "
886  << stringifyImageSamplerUseInfo(type.getSamplerUseInfo()) << ", "
887  << stringifyImageFormat(type.getImageFormat()) << ">";
888 }
889 
890 static void print(SampledImageType type, DialectAsmPrinter &os) {
891  os << "sampled_image<" << type.getImageType() << ">";
892 }
893 
894 static void print(StructType type, DialectAsmPrinter &os) {
895  FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint;
896 
897  os << "struct<";
898 
899  if (type.isIdentified()) {
900  os << type.getIdentifier();
901 
902  cyclicPrint = os.tryStartCyclicPrint(type);
903  if (failed(cyclicPrint)) {
904  os << ">";
905  return;
906  }
907 
908  os << ", ";
909  }
910 
911  os << "(";
912 
913  auto printMember = [&](unsigned i) {
914  os << type.getElementType(i);
916  type.getMemberDecorations(i, decorations);
917  if (type.hasOffset() || !decorations.empty()) {
918  os << " [";
919  if (type.hasOffset()) {
920  os << type.getMemberOffset(i);
921  if (!decorations.empty())
922  os << ", ";
923  }
924  auto eachFn = [&os](spirv::StructType::MemberDecorationInfo decoration) {
925  os << stringifyDecoration(decoration.decoration);
926  if (decoration.hasValue()) {
927  os << "=";
928  os.printAttributeWithoutType(decoration.decorationValue);
929  }
930  };
931  llvm::interleaveComma(decorations, os, eachFn);
932  os << "]";
933  }
934  };
935  llvm::interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os,
936  printMember);
937  os << ")";
938 
940  type.getStructDecorations(decorations);
941  if (!decorations.empty()) {
942  os << ", ";
943  auto eachFn = [&os](spirv::StructType::StructDecorationInfo decoration) {
944  os << stringifyDecoration(decoration.decoration);
945  if (decoration.hasValue()) {
946  os << "=";
947  os.printAttributeWithoutType(decoration.decorationValue);
948  }
949  };
950  llvm::interleaveComma(decorations, os, eachFn);
951  }
952 
953  os << ">";
954 }
955 
957  os << "coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x"
958  << type.getElementType() << ", " << type.getScope() << ", "
959  << type.getUse() << ">";
960 }
961 
962 static void print(MatrixType type, DialectAsmPrinter &os) {
963  os << "matrix<" << type.getNumColumns() << " x " << type.getColumnType();
964  os << ">";
965 }
966 
967 static void print(TensorArmType type, DialectAsmPrinter &os) {
968  os << "arm.tensor<";
969 
970  llvm::interleave(
971  type.getShape(), os,
972  [&](int64_t dim) {
973  if (ShapedType::isDynamic(dim))
974  os << '?';
975  else
976  os << dim;
977  },
978  "x");
979  if (!type.hasRank()) {
980  os << "*";
981  }
982  os << "x" << type.getElementType() << ">";
983 }
984 
985 void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
986  TypeSwitch<Type>(type)
989  [&](auto type) { print(type, os); })
990  .Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
991 }
992 
993 //===----------------------------------------------------------------------===//
994 // Constant
995 //===----------------------------------------------------------------------===//
996 
998  Attribute value, Type type,
999  Location loc) {
1000  if (auto poison = dyn_cast<ub::PoisonAttr>(value))
1001  return ub::PoisonOp::create(builder, loc, type, poison);
1002 
1003  if (!spirv::ConstantOp::isBuildableWith(type))
1004  return nullptr;
1005 
1006  return spirv::ConstantOp::create(builder, loc, type, value);
1007 }
1008 
1009 //===----------------------------------------------------------------------===//
1010 // Shader Interface ABI
1011 //===----------------------------------------------------------------------===//
1012 
1013 LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op,
1014  NamedAttribute attribute) {
1015  StringRef symbol = attribute.getName().strref();
1016  Attribute attr = attribute.getValue();
1017 
1018  if (symbol == spirv::getEntryPointABIAttrName()) {
1019  if (!llvm::isa<spirv::EntryPointABIAttr>(attr)) {
1020  return op->emitError("'")
1021  << symbol << "' attribute must be an entry point ABI attribute";
1022  }
1023  } else if (symbol == spirv::getTargetEnvAttrName()) {
1024  if (!llvm::isa<spirv::TargetEnvAttr>(attr))
1025  return op->emitError("'") << symbol << "' must be a spirv::TargetEnvAttr";
1026  } else {
1027  return op->emitError("found unsupported '")
1028  << symbol << "' attribute on operation";
1029  }
1030 
1031  return success();
1032 }
1033 
1034 /// Verifies the given SPIR-V `attribute` attached to a value of the given
1035 /// `valueType` is valid.
1036 static LogicalResult verifyRegionAttribute(Location loc, Type valueType,
1037  NamedAttribute attribute) {
1038  StringRef symbol = attribute.getName().strref();
1039  Attribute attr = attribute.getValue();
1040 
1041  if (symbol == spirv::getInterfaceVarABIAttrName()) {
1042  auto varABIAttr = llvm::dyn_cast<spirv::InterfaceVarABIAttr>(attr);
1043  if (!varABIAttr)
1044  return emitError(loc, "'")
1045  << symbol << "' must be a spirv::InterfaceVarABIAttr";
1046 
1047  if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat())
1048  return emitError(loc, "'") << symbol
1049  << "' attribute cannot specify storage class "
1050  "when attaching to a non-scalar value";
1051  return success();
1052  }
1053  if (symbol == spirv::DecorationAttr::name) {
1054  if (!isa<spirv::DecorationAttr>(attr))
1055  return emitError(loc, "'")
1056  << symbol << "' must be a spirv::DecorationAttr";
1057  return success();
1058  }
1059 
1060  return emitError(loc, "found unsupported '")
1061  << symbol << "' attribute on region argument";
1062 }
1063 
1064 LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
1065  unsigned regionIndex,
1066  unsigned argIndex,
1067  NamedAttribute attribute) {
1068  auto funcOp = dyn_cast<FunctionOpInterface>(op);
1069  if (!funcOp)
1070  return success();
1071  Type argType = funcOp.getArgumentTypes()[argIndex];
1072 
1073  return verifyRegionAttribute(op->getLoc(), argType, attribute);
1074 }
1075 
1076 LogicalResult SPIRVDialect::verifyRegionResultAttribute(
1077  Operation *op, unsigned /*regionIndex*/, unsigned resultIndex,
1078  NamedAttribute attribute) {
1079  if (auto graphOp = dyn_cast<spirv::GraphARMOp>(op))
1080  return verifyRegionAttribute(
1081  op->getLoc(), graphOp.getResultTypes()[resultIndex], attribute);
1082  return op->emitError(
1083  "cannot attach SPIR-V attributes to region result which is "
1084  "not part of a spirv::GraphARMOp type");
1085 }
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
std::optional< unsigned > parseAndVerify< unsigned >(SPIRVDialect const &dialect, DialectAsmParser &parser)
static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect, DialectAsmParser &parser, unsigned &stride)
Parses an optional `, stride = N` assembly segment.
static LogicalResult verifyRegionAttribute(Location loc, Type valueType, NamedAttribute attribute)
Verifies the given SPIR-V attribute attached to a value of the given valueType is valid.
static Type parseTensorArmType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static void print(ArrayType type, DialectAsmPrinter &os)
static Type parseSampledImageType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseAndVerifyType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static ParseResult parseStructMemberDecorations(SPIRVDialect const &dialect, DialectAsmParser &parser, ArrayRef< Type > memberTypes, SmallVectorImpl< StructType::OffsetInfo > &offsetInfo, SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorationInfo)
static Type parseAndVerifySampledImageType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseCooperativeMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
std::optional< Type > parseAndVerify< Type >(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseArrayType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static bool containsReturn(Region &region)
Returns true if the given region contains spirv.Return or spirv.ReturnValue ops.
static Type parseStructType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseRuntimeArrayType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static std::optional< IntTy > parseAndVerifyInteger(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseImageType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static std::optional< ValTy > parseAndVerify(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parsePointerType(SPIRVDialect const &dialect, DialectAsmParser &parser)
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
FailureOr< CyclicParseReset > tryStartCyclicParse(AttrOrTypeT attrOrType)
Attempts to start a cyclic parsing region for attrOrType.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseOptionalGreater()=0
Parse a '>' token if present.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token. This always succeeds.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalStar()=0
Parse a '*' token if present.
virtual ParseResult parseOptionalRSquare()=0
Parse a ] token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseXInDimensionList()=0
Parse an 'x' token in a dimension list, handling the case where the x is juxtaposed with an element t...
FailureOr< CyclicPrintReset > tryStartCyclicPrint(AttrOrTypeT attrOrType)
Attempts to start a cyclic printing region for attrOrType.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:55
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:179
This class helps build Operations.
Definition: Builders.h:207
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:40
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:50
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:107
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:120
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
Type getElementType() const
Definition: SPIRVTypes.cpp:64
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:66
unsigned getNumElements() const
Definition: SPIRVTypes.cpp:62
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:50
Scope getScope() const
Returns the scope of the matrix.
Definition: SPIRVTypes.cpp:281
uint32_t getRows() const
Returns the number of rows of the matrix.
Definition: SPIRVTypes.cpp:267
uint32_t getColumns() const
Returns the number of columns of the matrix.
Definition: SPIRVTypes.cpp:272
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
Definition: SPIRVTypes.cpp:255
CooperativeMatrixUseKHR getUse() const
Returns the use parameter of the cooperative matrix.
Definition: SPIRVTypes.cpp:283
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition: SPIRVTypes.h:170
ImageDepthInfo getDepthInfo() const
Definition: SPIRVTypes.cpp:390
ImageArrayedInfo getArrayedInfo() const
Definition: SPIRVTypes.cpp:392
ImageFormat getImageFormat() const
Definition: SPIRVTypes.cpp:404
ImageSamplerUseInfo getSamplerUseInfo() const
Definition: SPIRVTypes.cpp:400
Type getElementType() const
Definition: SPIRVTypes.cpp:386
ImageSamplingInfo getSamplingInfo() const
Definition: SPIRVTypes.cpp:396
static MatrixType get(Type columnType, uint32_t columnCount)
Type getColumnType() const
unsigned getNumColumns() const
Returns the number of columns.
Type getPointeeType() const
Definition: SPIRVTypes.cpp:451
StorageClass getStorageClass() const
Definition: SPIRVTypes.cpp:453
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:447
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:514
static RuntimeArrayType get(Type elementType)
Definition: SPIRVTypes.cpp:504
static SampledImageType get(Type imageType)
Definition: SPIRVTypes.cpp:793
static bool isValid(FloatType)
Returns true if the given float type is valid for the SPIR-V dialect.
Definition: SPIRVTypes.cpp:548
SPIR-V struct type.
Definition: SPIRVTypes.h:295
void getStructDecorations(SmallVectorImpl< StructType::StructDecorationInfo > &structDecorations) const
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
unsigned getNumElements() const
Type getElementType(unsigned) const
LogicalResult trySetBody(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Sets the contents of an incomplete identified StructType.
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
uint64_t getMemberOffset(unsigned) const
SPIR-V TensorARM Type.
Definition: SPIRVTypes.h:524
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
ArrayRef< int64_t > getShape() const
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
ParseResult parseEnumKeywordAttr(EnumClass &value, ParserType &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next keyword in parser as an enumerant of the given EnumClass.
StringRef getTargetEnvAttrName()
Returns the attribute name for specifying SPIR-V target environment.
StringRef getEntryPointABIAttrName()
Returns the attribute name for specifying entry point information.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type to an MLIR context if it was valid.