MLIR  14.0.0git
SPIRVDialect.cpp
Go to the documentation of this file.
//===- SPIRVDialect.cpp - MLIR SPIR-V dialect -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the SPIR-V dialect in MLIR.
10 //
11 //===----------------------------------------------------------------------===//
12 
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/MLIRContext.h"
22 #include "mlir/Parser.h"
24 #include "llvm/ADT/DenseMap.h"
25 #include "llvm/ADT/Sequence.h"
26 #include "llvm/ADT/SetVector.h"
27 #include "llvm/ADT/StringExtras.h"
28 #include "llvm/ADT/StringMap.h"
29 #include "llvm/ADT/StringSwitch.h"
30 #include "llvm/ADT/TypeSwitch.h"
31 #include "llvm/Support/raw_ostream.h"
32 
33 using namespace mlir;
34 using namespace mlir::spirv;
35 
36 #include "mlir/Dialect/SPIRV/IR/SPIRVOpsDialect.cpp.inc"
37 
38 //===----------------------------------------------------------------------===//
39 // InlinerInterface
40 //===----------------------------------------------------------------------===//
41 
42 /// Returns true if the given region contains spv.Return or spv.ReturnValue ops.
43 static inline bool containsReturn(Region &region) {
44  return llvm::any_of(region, [](Block &block) {
45  Operation *terminator = block.getTerminator();
46  return isa<spirv::ReturnOp, spirv::ReturnValueOp>(terminator);
47  });
48 }
49 
50 namespace {
51 /// This class defines the interface for inlining within the SPIR-V dialect.
52 struct SPIRVInlinerInterface : public DialectInlinerInterface {
54 
55  /// All call operations within SPIRV can be inlined.
56  bool isLegalToInline(Operation *call, Operation *callable,
57  bool wouldBeCloned) const final {
58  return true;
59  }
60 
61  /// Returns true if the given region 'src' can be inlined into the region
62  /// 'dest' that is attached to an operation registered to the current dialect.
63  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
64  BlockAndValueMapping &) const final {
65  // Return true here when inlining into spv.func, spv.mlir.selection, and
66  // spv.mlir.loop operations.
67  auto *op = dest->getParentOp();
68  return isa<spirv::FuncOp, spirv::SelectionOp, spirv::LoopOp>(op);
69  }
70 
71  /// Returns true if the given operation 'op', that is registered to this
72  /// dialect, can be inlined into the region 'dest' that is attached to an
73  /// operation registered to the current dialect.
74  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
75  BlockAndValueMapping &) const final {
76  // TODO: Enable inlining structured control flows with return.
77  if ((isa<spirv::SelectionOp, spirv::LoopOp>(op)) &&
78  containsReturn(op->getRegion(0)))
79  return false;
80  // TODO: we need to filter OpKill here to avoid inlining it to
81  // a loop continue construct:
82  // https://github.com/KhronosGroup/SPIRV-Headers/issues/86
83  // However OpKill is fragment shader specific and we don't support it yet.
84  return true;
85  }
86 
87  /// Handle the given inlined terminator by replacing it with a new operation
88  /// as necessary.
89  void handleTerminator(Operation *op, Block *newDest) const final {
90  if (auto returnOp = dyn_cast<spirv::ReturnOp>(op)) {
91  OpBuilder(op).create<spirv::BranchOp>(op->getLoc(), newDest);
92  op->erase();
93  } else if (auto retValOp = dyn_cast<spirv::ReturnValueOp>(op)) {
94  llvm_unreachable("unimplemented spv.ReturnValue in inliner");
95  }
96  }
97 
98  /// Handle the given inlined terminator by replacing it with a new operation
99  /// as necessary.
100  void handleTerminator(Operation *op,
101  ArrayRef<Value> valuesToRepl) const final {
102  // Only spv.ReturnValue needs to be handled here.
103  auto retValOp = dyn_cast<spirv::ReturnValueOp>(op);
104  if (!retValOp)
105  return;
106 
107  // Replace the values directly with the return operands.
108  assert(valuesToRepl.size() == 1 &&
109  "spv.ReturnValue expected to only handle one result");
110  valuesToRepl.front().replaceAllUsesWith(retValOp.value());
111  }
112 };
113 } // namespace
114 
115 //===----------------------------------------------------------------------===//
116 // SPIR-V Dialect
117 //===----------------------------------------------------------------------===//
118 
119 void SPIRVDialect::initialize() {
120  registerAttributes();
121  registerTypes();
122 
123  // Add SPIR-V ops.
124  addOperations<
125 #define GET_OP_LIST
126 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"
127  >();
128 
129  addInterfaces<SPIRVInlinerInterface>();
130 
131  // Allow unknown operations because SPIR-V is extensible.
132  allowUnknownOperations();
133 }
134 
135 std::string SPIRVDialect::getAttributeName(Decoration decoration) {
136  return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration));
137 }
138 
139 //===----------------------------------------------------------------------===//
140 // Type Parsing
141 //===----------------------------------------------------------------------===//
142 
143 // Forward declarations.
144 template <typename ValTy>
145 static Optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
147 template <>
148 Optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
149  DialectAsmParser &parser);
150 
151 template <>
152 Optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
153  DialectAsmParser &parser);
154 
155 static Type parseAndVerifyType(SPIRVDialect const &dialect,
156  DialectAsmParser &parser) {
157  Type type;
158  llvm::SMLoc typeLoc = parser.getCurrentLocation();
159  if (parser.parseType(type))
160  return Type();
161 
162  // Allow SPIR-V dialect types
163  if (&type.getDialect() == &dialect)
164  return type;
165 
166  // Check other allowed types
167  if (auto t = type.dyn_cast<FloatType>()) {
168  if (type.isBF16()) {
169  parser.emitError(typeLoc, "cannot use 'bf16' to compose SPIR-V types");
170  return Type();
171  }
172  } else if (auto t = type.dyn_cast<IntegerType>()) {
173  if (!ScalarType::isValid(t)) {
174  parser.emitError(typeLoc,
175  "only 1/8/16/32/64-bit integer type allowed but found ")
176  << type;
177  return Type();
178  }
179  } else if (auto t = type.dyn_cast<VectorType>()) {
180  if (t.getRank() != 1) {
181  parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
182  return Type();
183  }
184  if (t.getNumElements() > 4) {
185  parser.emitError(
186  typeLoc, "vector length has to be less than or equal to 4 but found ")
187  << t.getNumElements();
188  return Type();
189  }
190  } else {
191  parser.emitError(typeLoc, "cannot use ")
192  << type << " to compose SPIR-V types";
193  return Type();
194  }
195 
196  return type;
197 }
198 
199 static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect,
200  DialectAsmParser &parser) {
201  Type type;
202  llvm::SMLoc typeLoc = parser.getCurrentLocation();
203  if (parser.parseType(type))
204  return Type();
205 
206  if (auto t = type.dyn_cast<VectorType>()) {
207  if (t.getRank() != 1) {
208  parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
209  return Type();
210  }
211  if (t.getNumElements() > 4 || t.getNumElements() < 2) {
212  parser.emitError(typeLoc,
213  "matrix columns size has to be less than or equal "
214  "to 4 and greater than or equal 2, but found ")
215  << t.getNumElements();
216  return Type();
217  }
218 
219  if (!t.getElementType().isa<FloatType>()) {
220  parser.emitError(typeLoc, "matrix columns' elements must be of "
221  "Float type, got ")
222  << t.getElementType();
223  return Type();
224  }
225  } else {
226  parser.emitError(typeLoc, "matrix must be composed using vector "
227  "type, got ")
228  << type;
229  return Type();
230  }
231 
232  return type;
233 }
234 
235 static Type parseAndVerifySampledImageType(SPIRVDialect const &dialect,
236  DialectAsmParser &parser) {
237  Type type;
238  llvm::SMLoc typeLoc = parser.getCurrentLocation();
239  if (parser.parseType(type))
240  return Type();
241 
242  if (!type.isa<ImageType>()) {
243  parser.emitError(typeLoc,
244  "sampled image must be composed using image type, got ")
245  << type;
246  return Type();
247  }
248 
249  return type;
250 }
251 
252 /// Parses an optional `, stride = N` assembly segment. If no parsing failure
253 /// occurs, writes `N` to `stride` if existing and writes 0 to `stride` if
254 /// missing.
255 static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect,
256  DialectAsmParser &parser,
257  unsigned &stride) {
258  if (failed(parser.parseOptionalComma())) {
259  stride = 0;
260  return success();
261  }
262 
263  if (parser.parseKeyword("stride") || parser.parseEqual())
264  return failure();
265 
266  llvm::SMLoc strideLoc = parser.getCurrentLocation();
267  Optional<unsigned> optStride = parseAndVerify<unsigned>(dialect, parser);
268  if (!optStride)
269  return failure();
270 
271  if (!(stride = optStride.getValue())) {
272  parser.emitError(strideLoc, "ArrayStride must be greater than zero");
273  return failure();
274  }
275  return success();
276 }
277 
278 // element-type ::= integer-type
279 // | floating-point-type
280 // | vector-type
281 // | spirv-type
282 //
283 // array-type ::= `!spv.array` `<` integer-literal `x` element-type
284 // (`,` `stride` `=` integer-literal)? `>`
285 static Type parseArrayType(SPIRVDialect const &dialect,
286  DialectAsmParser &parser) {
287  if (parser.parseLess())
288  return Type();
289 
290  SmallVector<int64_t, 1> countDims;
291  llvm::SMLoc countLoc = parser.getCurrentLocation();
292  if (parser.parseDimensionList(countDims, /*allowDynamic=*/false))
293  return Type();
294  if (countDims.size() != 1) {
295  parser.emitError(countLoc,
296  "expected single integer for array element count");
297  return Type();
298  }
299 
300  // According to the SPIR-V spec:
301  // "Length is the number of elements in the array. It must be at least 1."
302  int64_t count = countDims[0];
303  if (count == 0) {
304  parser.emitError(countLoc, "expected array length greater than 0");
305  return Type();
306  }
307 
308  Type elementType = parseAndVerifyType(dialect, parser);
309  if (!elementType)
310  return Type();
311 
312  unsigned stride = 0;
313  if (failed(parseOptionalArrayStride(dialect, parser, stride)))
314  return Type();
315 
316  if (parser.parseGreater())
317  return Type();
318  return ArrayType::get(elementType, count, stride);
319 }
320 
// cooperative-matrix-type ::= `!spv.coopmatrix` `<` rows `x` columns `x`
//                             element-type `,` scope `>`
323 static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
324  DialectAsmParser &parser) {
325  if (parser.parseLess())
326  return Type();
327 
329  llvm::SMLoc countLoc = parser.getCurrentLocation();
330  if (parser.parseDimensionList(dims, /*allowDynamic=*/false))
331  return Type();
332 
333  if (dims.size() != 2) {
334  parser.emitError(countLoc, "expected rows and columns size");
335  return Type();
336  }
337 
338  auto elementTy = parseAndVerifyType(dialect, parser);
339  if (!elementTy)
340  return Type();
341 
342  Scope scope;
343  if (parser.parseComma() || parseEnumKeywordAttr(scope, parser, "scope <id>"))
344  return Type();
345 
346  if (parser.parseGreater())
347  return Type();
348  return CooperativeMatrixNVType::get(elementTy, scope, dims[0], dims[1]);
349 }
350 
351 // TODO: Reorder methods to be utilities first and parse*Type
352 // methods in alphabetical order
353 //
354 // storage-class ::= `UniformConstant`
355 // | `Uniform`
356 // | `Workgroup`
357 // | <and other storage classes...>
358 //
359 // pointer-type ::= `!spv.ptr<` element-type `,` storage-class `>`
360 static Type parsePointerType(SPIRVDialect const &dialect,
361  DialectAsmParser &parser) {
362  if (parser.parseLess())
363  return Type();
364 
365  auto pointeeType = parseAndVerifyType(dialect, parser);
366  if (!pointeeType)
367  return Type();
368 
369  StringRef storageClassSpec;
370  llvm::SMLoc storageClassLoc = parser.getCurrentLocation();
371  if (parser.parseComma() || parser.parseKeyword(&storageClassSpec))
372  return Type();
373 
374  auto storageClass = symbolizeStorageClass(storageClassSpec);
375  if (!storageClass) {
376  parser.emitError(storageClassLoc, "unknown storage class: ")
377  << storageClassSpec;
378  return Type();
379  }
380  if (parser.parseGreater())
381  return Type();
382  return PointerType::get(pointeeType, *storageClass);
383 }
384 
385 // runtime-array-type ::= `!spv.rtarray` `<` element-type
386 // (`,` `stride` `=` integer-literal)? `>`
387 static Type parseRuntimeArrayType(SPIRVDialect const &dialect,
388  DialectAsmParser &parser) {
389  if (parser.parseLess())
390  return Type();
391 
392  Type elementType = parseAndVerifyType(dialect, parser);
393  if (!elementType)
394  return Type();
395 
396  unsigned stride = 0;
397  if (failed(parseOptionalArrayStride(dialect, parser, stride)))
398  return Type();
399 
400  if (parser.parseGreater())
401  return Type();
402  return RuntimeArrayType::get(elementType, stride);
403 }
404 
405 // matrix-type ::= `!spv.matrix` `<` integer-literal `x` element-type `>`
406 static Type parseMatrixType(SPIRVDialect const &dialect,
407  DialectAsmParser &parser) {
408  if (parser.parseLess())
409  return Type();
410 
411  SmallVector<int64_t, 1> countDims;
412  llvm::SMLoc countLoc = parser.getCurrentLocation();
413  if (parser.parseDimensionList(countDims, /*allowDynamic=*/false))
414  return Type();
415  if (countDims.size() != 1) {
416  parser.emitError(countLoc, "expected single unsigned "
417  "integer for number of columns");
418  return Type();
419  }
420 
421  int64_t columnCount = countDims[0];
422  // According to the specification, Matrices can have 2, 3, or 4 columns
423  if (columnCount < 2 || columnCount > 4) {
424  parser.emitError(countLoc, "matrix is expected to have 2, 3, or 4 "
425  "columns");
426  return Type();
427  }
428 
429  Type columnType = parseAndVerifyMatrixType(dialect, parser);
430  if (!columnType)
431  return Type();
432 
433  if (parser.parseGreater())
434  return Type();
435 
436  return MatrixType::get(columnType, columnCount);
437 }
438 
439 // Specialize this function to parse each of the parameters that define an
440 // ImageType. By default it assumes this is an enum type.
441 template <typename ValTy>
442 static Optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
443  DialectAsmParser &parser) {
444  StringRef enumSpec;
445  llvm::SMLoc enumLoc = parser.getCurrentLocation();
446  if (parser.parseKeyword(&enumSpec)) {
447  return llvm::None;
448  }
449 
450  auto val = spirv::symbolizeEnum<ValTy>(enumSpec);
451  if (!val)
452  parser.emitError(enumLoc, "unknown attribute: '") << enumSpec << "'";
453  return val;
454 }
455 
456 template <>
457 Optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
458  DialectAsmParser &parser) {
459  // TODO: Further verify that the element type can be sampled
460  auto ty = parseAndVerifyType(dialect, parser);
461  if (!ty)
462  return llvm::None;
463  return ty;
464 }
465 
466 template <typename IntTy>
467 static Optional<IntTy> parseAndVerifyInteger(SPIRVDialect const &dialect,
468  DialectAsmParser &parser) {
469  IntTy offsetVal = std::numeric_limits<IntTy>::max();
470  if (parser.parseInteger(offsetVal))
471  return llvm::None;
472  return offsetVal;
473 }
474 
475 template <>
476 Optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
477  DialectAsmParser &parser) {
478  return parseAndVerifyInteger<unsigned>(dialect, parser);
479 }
480 
481 namespace {
482 // Functor object to parse a comma separated list of specs. The function
483 // parseAndVerify does the actual parsing and verification of individual
484 // elements. This is a functor since parsing the last element of the list
485 // (termination condition) needs partial specialization.
486 template <typename ParseType, typename... Args> struct ParseCommaSeparatedList {
487  Optional<std::tuple<ParseType, Args...>>
488  operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const {
489  auto parseVal = parseAndVerify<ParseType>(dialect, parser);
490  if (!parseVal)
491  return llvm::None;
492 
493  auto numArgs = std::tuple_size<std::tuple<Args...>>::value;
494  if (numArgs != 0 && failed(parser.parseComma()))
495  return llvm::None;
496  auto remainingValues = ParseCommaSeparatedList<Args...>{}(dialect, parser);
497  if (!remainingValues)
498  return llvm::None;
499  return std::tuple_cat(std::tuple<ParseType>(parseVal.getValue()),
500  remainingValues.getValue());
501  }
502 };
503 
504 // Partial specialization of the function to parse a comma separated list of
505 // specs to parse the last element of the list.
506 template <typename ParseType> struct ParseCommaSeparatedList<ParseType> {
507  Optional<std::tuple<ParseType>> operator()(SPIRVDialect const &dialect,
508  DialectAsmParser &parser) const {
509  if (auto value = parseAndVerify<ParseType>(dialect, parser))
510  return std::tuple<ParseType>(value.getValue());
511  return llvm::None;
512  }
513 };
514 } // namespace
515 
516 // dim ::= `1D` | `2D` | `3D` | `Cube` | <and other SPIR-V Dim specifiers...>
517 //
518 // depth-info ::= `NoDepth` | `IsDepth` | `DepthUnknown`
519 //
520 // arrayed-info ::= `NonArrayed` | `Arrayed`
521 //
522 // sampling-info ::= `SingleSampled` | `MultiSampled`
523 //
524 // sampler-use-info ::= `SamplerUnknown` | `NeedSampler` | `NoSampler`
525 //
526 // format ::= `Unknown` | `Rgba32f` | <and other SPIR-V Image formats...>
527 //
528 // image-type ::= `!spv.image<` element-type `,` dim `,` depth-info `,`
529 // arrayed-info `,` sampling-info `,`
530 // sampler-use-info `,` format `>`
531 static Type parseImageType(SPIRVDialect const &dialect,
532  DialectAsmParser &parser) {
533  if (parser.parseLess())
534  return Type();
535 
536  auto value =
537  ParseCommaSeparatedList<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
538  ImageSamplingInfo, ImageSamplerUseInfo,
539  ImageFormat>{}(dialect, parser);
540  if (!value)
541  return Type();
542 
543  if (parser.parseGreater())
544  return Type();
545  return ImageType::get(value.getValue());
546 }
547 
548 // sampledImage-type :: = `!spv.sampledImage<` image-type `>`
549 static Type parseSampledImageType(SPIRVDialect const &dialect,
550  DialectAsmParser &parser) {
551  if (parser.parseLess())
552  return Type();
553 
554  Type parsedType = parseAndVerifySampledImageType(dialect, parser);
555  if (!parsedType)
556  return Type();
557 
558  if (parser.parseGreater())
559  return Type();
560  return SampledImageType::get(parsedType);
561 }
562 
563 // Parse decorations associated with a member.
565  SPIRVDialect const &dialect, DialectAsmParser &parser,
566  ArrayRef<Type> memberTypes,
569 
570  // Check if the first element is offset.
571  llvm::SMLoc offsetLoc = parser.getCurrentLocation();
572  StructType::OffsetInfo offset = 0;
573  OptionalParseResult offsetParseResult = parser.parseOptionalInteger(offset);
574  if (offsetParseResult.hasValue()) {
575  if (failed(*offsetParseResult))
576  return failure();
577 
578  if (offsetInfo.size() != memberTypes.size() - 1) {
579  return parser.emitError(offsetLoc,
580  "offset specification must be given for "
581  "all members");
582  }
583  offsetInfo.push_back(offset);
584  }
585 
586  // Check for no spirv::Decorations.
587  if (succeeded(parser.parseOptionalRSquare()))
588  return success();
589 
590  // If there was an offset, make sure to parse the comma.
591  if (offsetParseResult.hasValue() && parser.parseComma())
592  return failure();
593 
594  // Check for spirv::Decorations.
595  do {
596  auto memberDecoration = parseAndVerify<spirv::Decoration>(dialect, parser);
597  if (!memberDecoration)
598  return failure();
599 
600  // Parse member decoration value if it exists.
601  if (succeeded(parser.parseOptionalEqual())) {
602  auto memberDecorationValue =
603  parseAndVerifyInteger<uint32_t>(dialect, parser);
604 
605  if (!memberDecorationValue)
606  return failure();
607 
608  memberDecorationInfo.emplace_back(
609  static_cast<uint32_t>(memberTypes.size() - 1), 1,
610  memberDecoration.getValue(), memberDecorationValue.getValue());
611  } else {
612  memberDecorationInfo.emplace_back(
613  static_cast<uint32_t>(memberTypes.size() - 1), 0,
614  memberDecoration.getValue(), 0);
615  }
616 
617  } while (succeeded(parser.parseOptionalComma()));
618 
619  return parser.parseRSquare();
620 }
621 
622 // struct-member-decoration ::= integer-literal? spirv-decoration*
623 // struct-type ::=
624 // `!spv.struct<` (id `,`)?
625 // `(`
626 // (spirv-type (`[` struct-member-decoration `]`)?)*
627 // `)>`
628 static Type parseStructType(SPIRVDialect const &dialect,
629  DialectAsmParser &parser) {
630  // TODO: This function is quite lengthy. Break it down into smaller chunks.
631 
632  // To properly resolve recursive references while parsing recursive struct
633  // types, we need to maintain a list of enclosing struct type names. This set
634  // maintains the names of struct types in which the type we are about to parse
635  // is nested.
636  //
637  // Note: This has to be thread_local to enable multiple threads to safely
638  // parse concurrently.
639  thread_local SetVector<StringRef> structContext;
640 
641  static auto removeIdentifierAndFail = [](SetVector<StringRef> &structContext,
642  StringRef identifier) {
643  if (!identifier.empty())
644  structContext.remove(identifier);
645 
646  return Type();
647  };
648 
649  if (parser.parseLess())
650  return Type();
651 
652  StringRef identifier;
653 
654  // Check if this is an identified struct type.
655  if (succeeded(parser.parseOptionalKeyword(&identifier))) {
656  // Check if this is a possible recursive reference.
657  if (succeeded(parser.parseOptionalGreater())) {
658  if (structContext.count(identifier) == 0) {
659  parser.emitError(
660  parser.getNameLoc(),
661  "recursive struct reference not nested in struct definition");
662 
663  return Type();
664  }
665 
666  return StructType::getIdentified(dialect.getContext(), identifier);
667  }
668 
669  if (failed(parser.parseComma()))
670  return Type();
671 
672  if (structContext.count(identifier) != 0) {
673  parser.emitError(parser.getNameLoc(),
674  "identifier already used for an enclosing struct");
675 
676  return removeIdentifierAndFail(structContext, identifier);
677  }
678 
679  structContext.insert(identifier);
680  }
681 
682  if (failed(parser.parseLParen()))
683  return removeIdentifierAndFail(structContext, identifier);
684 
685  if (succeeded(parser.parseOptionalRParen()) &&
686  succeeded(parser.parseOptionalGreater())) {
687  if (!identifier.empty())
688  structContext.remove(identifier);
689 
690  return StructType::getEmpty(dialect.getContext(), identifier);
691  }
692 
693  StructType idStructTy;
694 
695  if (!identifier.empty())
696  idStructTy = StructType::getIdentified(dialect.getContext(), identifier);
697 
698  SmallVector<Type, 4> memberTypes;
701 
702  do {
703  Type memberType;
704  if (parser.parseType(memberType))
705  return removeIdentifierAndFail(structContext, identifier);
706  memberTypes.push_back(memberType);
707 
708  if (succeeded(parser.parseOptionalLSquare()))
709  if (parseStructMemberDecorations(dialect, parser, memberTypes, offsetInfo,
710  memberDecorationInfo))
711  return removeIdentifierAndFail(structContext, identifier);
712  } while (succeeded(parser.parseOptionalComma()));
713 
714  if (!offsetInfo.empty() && memberTypes.size() != offsetInfo.size()) {
715  parser.emitError(parser.getNameLoc(),
716  "offset specification must be given for all members");
717  return removeIdentifierAndFail(structContext, identifier);
718  }
719 
720  if (failed(parser.parseRParen()) || failed(parser.parseGreater()))
721  return removeIdentifierAndFail(structContext, identifier);
722 
723  if (!identifier.empty()) {
724  if (failed(idStructTy.trySetBody(memberTypes, offsetInfo,
725  memberDecorationInfo)))
726  return Type();
727 
728  structContext.remove(identifier);
729  return idStructTy;
730  }
731 
732  return StructType::get(memberTypes, offsetInfo, memberDecorationInfo);
733 }
734 
// spirv-type ::= array-type
//              | cooperative-matrix-type
//              | element-type
//              | image-type
//              | matrix-type
//              | pointer-type
//              | runtime-array-type
//              | sampled-image-type
//              | struct-type
743  StringRef keyword;
744  if (parser.parseKeyword(&keyword))
745  return Type();
746 
747  if (keyword == "array")
748  return parseArrayType(*this, parser);
749  if (keyword == "coopmatrix")
750  return parseCooperativeMatrixType(*this, parser);
751  if (keyword == "image")
752  return parseImageType(*this, parser);
753  if (keyword == "ptr")
754  return parsePointerType(*this, parser);
755  if (keyword == "rtarray")
756  return parseRuntimeArrayType(*this, parser);
757  if (keyword == "sampled_image")
758  return parseSampledImageType(*this, parser);
759  if (keyword == "struct")
760  return parseStructType(*this, parser);
761  if (keyword == "matrix")
762  return parseMatrixType(*this, parser);
763  parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword;
764  return Type();
765 }
766 
767 //===----------------------------------------------------------------------===//
768 // Type Printing
769 //===----------------------------------------------------------------------===//
770 
771 static void print(ArrayType type, DialectAsmPrinter &os) {
772  os << "array<" << type.getNumElements() << " x " << type.getElementType();
773  if (unsigned stride = type.getArrayStride())
774  os << ", stride=" << stride;
775  os << ">";
776 }
777 
778 static void print(RuntimeArrayType type, DialectAsmPrinter &os) {
779  os << "rtarray<" << type.getElementType();
780  if (unsigned stride = type.getArrayStride())
781  os << ", stride=" << stride;
782  os << ">";
783 }
784 
785 static void print(PointerType type, DialectAsmPrinter &os) {
786  os << "ptr<" << type.getPointeeType() << ", "
787  << stringifyStorageClass(type.getStorageClass()) << ">";
788 }
789 
790 static void print(ImageType type, DialectAsmPrinter &os) {
791  os << "image<" << type.getElementType() << ", " << stringifyDim(type.getDim())
792  << ", " << stringifyImageDepthInfo(type.getDepthInfo()) << ", "
793  << stringifyImageArrayedInfo(type.getArrayedInfo()) << ", "
794  << stringifyImageSamplingInfo(type.getSamplingInfo()) << ", "
795  << stringifyImageSamplerUseInfo(type.getSamplerUseInfo()) << ", "
796  << stringifyImageFormat(type.getImageFormat()) << ">";
797 }
798 
799 static void print(SampledImageType type, DialectAsmPrinter &os) {
800  os << "sampled_image<" << type.getImageType() << ">";
801 }
802 
803 static void print(StructType type, DialectAsmPrinter &os) {
804  thread_local SetVector<StringRef> structContext;
805 
806  os << "struct<";
807 
808  if (type.isIdentified()) {
809  os << type.getIdentifier();
810 
811  if (structContext.count(type.getIdentifier())) {
812  os << ">";
813  return;
814  }
815 
816  os << ", ";
817  structContext.insert(type.getIdentifier());
818  }
819 
820  os << "(";
821 
822  auto printMember = [&](unsigned i) {
823  os << type.getElementType(i);
825  type.getMemberDecorations(i, decorations);
826  if (type.hasOffset() || !decorations.empty()) {
827  os << " [";
828  if (type.hasOffset()) {
829  os << type.getMemberOffset(i);
830  if (!decorations.empty())
831  os << ", ";
832  }
833  auto eachFn = [&os](spirv::StructType::MemberDecorationInfo decoration) {
834  os << stringifyDecoration(decoration.decoration);
835  if (decoration.hasValue) {
836  os << "=" << decoration.decorationValue;
837  }
838  };
839  llvm::interleaveComma(decorations, os, eachFn);
840  os << "]";
841  }
842  };
843  llvm::interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os,
844  printMember);
845  os << ")>";
846 
847  if (type.isIdentified())
848  structContext.remove(type.getIdentifier());
849 }
850 
852  os << "coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x";
853  os << type.getElementType() << ", " << stringifyScope(type.getScope());
854  os << ">";
855 }
856 
857 static void print(MatrixType type, DialectAsmPrinter &os) {
858  os << "matrix<" << type.getNumColumns() << " x " << type.getColumnType();
859  os << ">";
860 }
861 
862 void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
863  TypeSwitch<Type>(type)
866  [&](auto type) { print(type, os); })
867  .Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
868 }
869 
870 //===----------------------------------------------------------------------===//
871 // Attribute Parsing
872 //===----------------------------------------------------------------------===//
873 
/// Parses a `[`...`]`-delimited, comma-separated (possibly empty) list of
/// keywords, invokes `processKeyword` on each parsed keyword, and returns
/// failure if any error occurs.
877  DialectAsmParser &parser,
878  function_ref<LogicalResult(llvm::SMLoc, StringRef)> processKeyword) {
879  if (parser.parseLSquare())
880  return failure();
881 
882  // Special case for empty list.
883  if (succeeded(parser.parseOptionalRSquare()))
884  return success();
885 
886  // Keep parsing the keyword and an optional comma following it. If the comma
887  // is successfully parsed, then we have more keywords to parse.
888  do {
889  auto loc = parser.getCurrentLocation();
890  StringRef keyword;
891  if (parser.parseKeyword(&keyword) || failed(processKeyword(loc, keyword)))
892  return failure();
893  } while (succeeded(parser.parseOptionalComma()));
894 
895  if (parser.parseRSquare())
896  return failure();
897 
898  return success();
899 }
900 
901 /// Parses a spirv::InterfaceVarABIAttr.
903  if (parser.parseLess())
904  return {};
905 
906  Builder &builder = parser.getBuilder();
907 
908  if (parser.parseLParen())
909  return {};
910 
911  IntegerAttr descriptorSetAttr;
912  {
913  auto loc = parser.getCurrentLocation();
914  uint32_t descriptorSet = 0;
915  auto descriptorSetParseResult = parser.parseOptionalInteger(descriptorSet);
916 
917  if (!descriptorSetParseResult.hasValue() ||
918  failed(*descriptorSetParseResult)) {
919  parser.emitError(loc, "missing descriptor set");
920  return {};
921  }
922  descriptorSetAttr = builder.getI32IntegerAttr(descriptorSet);
923  }
924 
925  if (parser.parseComma())
926  return {};
927 
928  IntegerAttr bindingAttr;
929  {
930  auto loc = parser.getCurrentLocation();
931  uint32_t binding = 0;
932  auto bindingParseResult = parser.parseOptionalInteger(binding);
933 
934  if (!bindingParseResult.hasValue() || failed(*bindingParseResult)) {
935  parser.emitError(loc, "missing binding");
936  return {};
937  }
938  bindingAttr = builder.getI32IntegerAttr(binding);
939  }
940 
941  if (parser.parseRParen())
942  return {};
943 
944  IntegerAttr storageClassAttr;
945  {
946  if (succeeded(parser.parseOptionalComma())) {
947  auto loc = parser.getCurrentLocation();
948  StringRef storageClass;
949  if (parser.parseKeyword(&storageClass))
950  return {};
951 
952  if (auto storageClassSymbol =
953  spirv::symbolizeStorageClass(storageClass)) {
954  storageClassAttr = builder.getI32IntegerAttr(
955  static_cast<uint32_t>(*storageClassSymbol));
956  } else {
957  parser.emitError(loc, "unknown storage class: ") << storageClass;
958  return {};
959  }
960  }
961  }
962 
963  if (parser.parseGreater())
964  return {};
965 
966  return spirv::InterfaceVarABIAttr::get(descriptorSetAttr, bindingAttr,
967  storageClassAttr);
968 }
969 
971  if (parser.parseLess())
972  return {};
973 
974  Builder &builder = parser.getBuilder();
975 
976  IntegerAttr versionAttr;
977  {
978  auto loc = parser.getCurrentLocation();
979  StringRef version;
980  if (parser.parseKeyword(&version) || parser.parseComma())
981  return {};
982 
983  if (auto versionSymbol = spirv::symbolizeVersion(version)) {
984  versionAttr =
985  builder.getI32IntegerAttr(static_cast<uint32_t>(*versionSymbol));
986  } else {
987  parser.emitError(loc, "unknown version: ") << version;
988  return {};
989  }
990  }
991 
992  ArrayAttr capabilitiesAttr;
993  {
994  SmallVector<Attribute, 4> capabilities;
995  llvm::SMLoc errorloc;
996  StringRef errorKeyword;
997 
998  auto processCapability = [&](llvm::SMLoc loc, StringRef capability) {
999  if (auto capSymbol = spirv::symbolizeCapability(capability)) {
1000  capabilities.push_back(
1001  builder.getI32IntegerAttr(static_cast<uint32_t>(*capSymbol)));
1002  return success();
1003  }
1004  return errorloc = loc, errorKeyword = capability, failure();
1005  };
1006  if (parseKeywordList(parser, processCapability) || parser.parseComma()) {
1007  if (!errorKeyword.empty())
1008  parser.emitError(errorloc, "unknown capability: ") << errorKeyword;
1009  return {};
1010  }
1011 
1012  capabilitiesAttr = builder.getArrayAttr(capabilities);
1013  }
1014 
1015  ArrayAttr extensionsAttr;
1016  {
1017  SmallVector<Attribute, 1> extensions;
1018  llvm::SMLoc errorloc;
1019  StringRef errorKeyword;
1020 
1021  auto processExtension = [&](llvm::SMLoc loc, StringRef extension) {
1022  if (spirv::symbolizeExtension(extension)) {
1023  extensions.push_back(builder.getStringAttr(extension));
1024  return success();
1025  }
1026  return errorloc = loc, errorKeyword = extension, failure();
1027  };
1028  if (parseKeywordList(parser, processExtension)) {
1029  if (!errorKeyword.empty())
1030  parser.emitError(errorloc, "unknown extension: ") << errorKeyword;
1031  return {};
1032  }
1033 
1034  extensionsAttr = builder.getArrayAttr(extensions);
1035  }
1036 
1037  if (parser.parseGreater())
1038  return {};
1039 
1040  return spirv::VerCapExtAttr::get(versionAttr, capabilitiesAttr,
1041  extensionsAttr);
1042 }
1043 
1044 /// Parses a spirv::TargetEnvAttr.
1046  if (parser.parseLess())
1047  return {};
1048 
1049  spirv::VerCapExtAttr tripleAttr;
1050  if (parser.parseAttribute(tripleAttr) || parser.parseComma())
1051  return {};
1052 
1053  // Parse [vendor[:device-type[:device-id]]]
1054  Vendor vendorID = Vendor::Unknown;
1055  DeviceType deviceType = DeviceType::Unknown;
1056  uint32_t deviceID = spirv::TargetEnvAttr::kUnknownDeviceID;
1057  {
1058  auto loc = parser.getCurrentLocation();
1059  StringRef vendorStr;
1060  if (succeeded(parser.parseOptionalKeyword(&vendorStr))) {
1061  if (auto vendorSymbol = spirv::symbolizeVendor(vendorStr)) {
1062  vendorID = *vendorSymbol;
1063  } else {
1064  parser.emitError(loc, "unknown vendor: ") << vendorStr;
1065  }
1066 
1067  if (succeeded(parser.parseOptionalColon())) {
1068  loc = parser.getCurrentLocation();
1069  StringRef deviceTypeStr;
1070  if (parser.parseKeyword(&deviceTypeStr))
1071  return {};
1072  if (auto deviceTypeSymbol = spirv::symbolizeDeviceType(deviceTypeStr)) {
1073  deviceType = *deviceTypeSymbol;
1074  } else {
1075  parser.emitError(loc, "unknown device type: ") << deviceTypeStr;
1076  }
1077 
1078  if (succeeded(parser.parseOptionalColon())) {
1079  loc = parser.getCurrentLocation();
1080  if (parser.parseInteger(deviceID))
1081  return {};
1082  }
1083  }
1084  if (parser.parseComma())
1085  return {};
1086  }
1087  }
1088 
1089  DictionaryAttr limitsAttr;
1090  {
1091  auto loc = parser.getCurrentLocation();
1092  if (parser.parseAttribute(limitsAttr))
1093  return {};
1094 
1095  if (!limitsAttr.isa<spirv::ResourceLimitsAttr>()) {
1096  parser.emitError(
1097  loc,
1098  "limits must be a dictionary attribute containing two 32-bit integer "
1099  "attributes 'max_compute_workgroup_invocations' and "
1100  "'max_compute_workgroup_size'");
1101  return {};
1102  }
1103  }
1104 
1105  if (parser.parseGreater())
1106  return {};
1107 
1108  return spirv::TargetEnvAttr::get(tripleAttr, vendorID, deviceType, deviceID,
1109  limitsAttr);
1110 }
1111 
1113  Type type) const {
  // NOTE(review): the first line of this signature is missing from this
  // listing; presumably it is the dialect hook
  // `Attribute SPIRVDialect::parseAttribute(DialectAsmParser &parser,`.
  //
  // Reads the attribute kind keyword (the name without the `spv.` prefix) and
  // dispatches to the matching per-kind parser. A non-null `type`, a failed
  // keyword parse, or an unrecognized kind yields a null Attribute; the
  // unexpected-type and unknown-kind cases are diagnosed here.
1114  // SPIR-V attributes are dictionaries so they do not have type.
1115  if (type) {
1116  parser.emitError(parser.getNameLoc(), "unexpected type");
1117  return {};
1118  }
1119 
1120  // Parse the kind keyword first.
1121  StringRef attrKind;
1122  if (parser.parseKeyword(&attrKind))
1123  return {};
1124 
  // Each per-kind parser consumes the remainder, starting at its `<`.
1125  if (attrKind == spirv::TargetEnvAttr::getKindName())
1126  return parseTargetEnvAttr(parser);
1127  if (attrKind == spirv::VerCapExtAttr::getKindName())
1128  return parseVerCapExtAttr(parser);
1129  if (attrKind == spirv::InterfaceVarABIAttr::getKindName())
1130  return parseInterfaceVarABIAttr(parser);
1131 
1132  parser.emitError(parser.getNameLoc(), "unknown SPIR-V attribute kind: ")
1133  << attrKind;
1134  return {};
1135 }
1136 
1137 //===----------------------------------------------------------------------===//
1138 // Attribute Printing
1139 //===----------------------------------------------------------------------===//
1140 
1141 static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer) {
1142  auto &os = printer.getStream();
1143  printer << spirv::VerCapExtAttr::getKindName() << "<"
1144  << spirv::stringifyVersion(triple.getVersion()) << ", [";
1145  llvm::interleaveComma(
1146  triple.getCapabilities(), os,
1147  [&](spirv::Capability cap) { os << spirv::stringifyCapability(cap); });
1148  printer << "], [";
1149  llvm::interleaveComma(triple.getExtensionsAttr(), os, [&](Attribute attr) {
1150  os << attr.cast<StringAttr>().getValue();
1151  });
1152  printer << "]>";
1153 }
1154 
1155 static void print(spirv::TargetEnvAttr targetEnv, DialectAsmPrinter &printer) {
1156  printer << spirv::TargetEnvAttr::getKindName() << "<#spv.";
1157  print(targetEnv.getTripleAttr(), printer);
1158  spirv::Vendor vendorID = targetEnv.getVendorID();
1159  spirv::DeviceType deviceType = targetEnv.getDeviceType();
1160  uint32_t deviceID = targetEnv.getDeviceID();
1161  if (vendorID != spirv::Vendor::Unknown) {
1162  printer << ", " << spirv::stringifyVendor(vendorID);
1163  if (deviceType != spirv::DeviceType::Unknown) {
1164  printer << ":" << spirv::stringifyDeviceType(deviceType);
1166  printer << ":" << deviceID;
1167  }
1168  }
1169  printer << ", " << targetEnv.getResourceLimits() << ">";
1170 }
1171 
1172 static void print(spirv::InterfaceVarABIAttr interfaceVarABIAttr,
1173  DialectAsmPrinter &printer) {
1174  printer << spirv::InterfaceVarABIAttr::getKindName() << "<("
1175  << interfaceVarABIAttr.getDescriptorSet() << ", "
1176  << interfaceVarABIAttr.getBinding() << ")";
1177  auto storageClass = interfaceVarABIAttr.getStorageClass();
1178  if (storageClass)
1179  printer << ", " << spirv::stringifyStorageClass(*storageClass);
1180  printer << ">";
1181 }
1182 
1183 void SPIRVDialect::printAttribute(Attribute attr,
1184  DialectAsmPrinter &printer) const {
1185  if (auto targetEnv = attr.dyn_cast<TargetEnvAttr>())
1186  print(targetEnv, printer);
1187  else if (auto vceAttr = attr.dyn_cast<VerCapExtAttr>())
1188  print(vceAttr, printer);
1189  else if (auto interfaceVarABIAttr = attr.dyn_cast<InterfaceVarABIAttr>())
1190  print(interfaceVarABIAttr, printer);
1191  else
1192  llvm_unreachable("unhandled SPIR-V attribute kind");
1193 }
1194 
1195 //===----------------------------------------------------------------------===//
1196 // Constant
1197 //===----------------------------------------------------------------------===//
1198 
1200  Attribute value, Type type,
1201  Location loc) {
  // Materializes `value` as a spirv::ConstantOp with result type `type` at
  // `loc`. Returns nullptr, creating nothing, when spirv::ConstantOp cannot
  // be built with `type`.
  // NOTE(review): the opening signature line is missing from this listing;
  // presumably `Operation *SPIRVDialect::materializeConstant(OpBuilder &builder,`,
  // i.e. the dialect hook used by constant folding. Confirm against the source.
1202  if (!spirv::ConstantOp::isBuildableWith(type))
1203  return nullptr;
1204 
1205  return builder.create<spirv::ConstantOp>(loc, type, value);
1206 }
1207 
1208 //===----------------------------------------------------------------------===//
1209 // Shader Interface ABI
1210 //===----------------------------------------------------------------------===//
1211 
1212 LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op,
1213  NamedAttribute attribute) {
1214  StringRef symbol = attribute.getName().strref();
1215  Attribute attr = attribute.getValue();
1216 
1217  // TODO: figure out a way to generate the description from the
1218  // StructAttr definition.
1219  if (symbol == spirv::getEntryPointABIAttrName()) {
1220  if (!attr.isa<spirv::EntryPointABIAttr>())
1221  return op->emitError("'")
1222  << symbol
1223  << "' attribute must be a dictionary attribute containing one "
1224  "32-bit integer elements attribute: 'local_size'";
1225  } else if (symbol == spirv::getTargetEnvAttrName()) {
1226  if (!attr.isa<spirv::TargetEnvAttr>())
1227  return op->emitError("'") << symbol << "' must be a spirv::TargetEnvAttr";
1228  } else {
1229  return op->emitError("found unsupported '")
1230  << symbol << "' attribute on operation";
1231  }
1232 
1233  return success();
1234 }
1235 
1236 /// Verifies the given SPIR-V `attribute` attached to a value of the given
1237 /// `valueType` is valid.
1239  NamedAttribute attribute) {
1240  StringRef symbol = attribute.getName().strref();
1241  Attribute attr = attribute.getValue();
1242 
1243  if (symbol != spirv::getInterfaceVarABIAttrName())
1244  return emitError(loc, "found unsupported '")
1245  << symbol << "' attribute on region argument";
1246 
1247  auto varABIAttr = attr.dyn_cast<spirv::InterfaceVarABIAttr>();
1248  if (!varABIAttr)
1249  return emitError(loc, "'")
1250  << symbol << "' must be a spirv::InterfaceVarABIAttr";
1251 
1252  if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat())
1253  return emitError(loc, "'") << symbol
1254  << "' attribute cannot specify storage class "
1255  "when attaching to a non-scalar value";
1256 
1257  return success();
1258 }
1259 
1260 LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
1261  unsigned regionIndex,
1262  unsigned argIndex,
1263  NamedAttribute attribute) {
1264  return verifyRegionAttribute(
1265  op->getLoc(), op->getRegion(regionIndex).getArgument(argIndex).getType(),
1266  attribute);
1267 }
1268 
1269 LogicalResult SPIRVDialect::verifyRegionResultAttribute(
1270  Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/,
1271  NamedAttribute attribute) {
1272  return op->emitError("cannot attach SPIR-V attributes to region result");
1273 }
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, BlockAndValueMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
Include the generated interface declarations.
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context)
This parses a single MLIR attribute to an MLIR context if it was valid.
OpTy create(Location location, Args &&...args)
Create an operation of specific op type at the current insertion point.
Definition: Builders.h:430
This class contains a list of basic blocks and a link to the parent operation it is attached to...
Definition: Region.h:26
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:114
virtual ParseResult parseLParen()=0
Parse a ( token.
ImageFormat getImageFormat() const
Definition: SPIRVTypes.cpp:349
virtual raw_ostream & getStream() const
Return the raw output stream used by this printer.
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:151
Optional< Type > parseAndVerify< Type >(SPIRVDialect const &dialect, DialectAsmParser &parser)
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:64
static StringRef getKindName()
Returns the attribute kind's name (without the 'spv.' prefix).
Type getPointeeType() const
Definition: SPIRVTypes.cpp:395
Block represents an ordered list of Operations.
Definition: Block.h:29
static Type parseImageType(SPIRVDialect const &dialect, DialectAsmParser &parser)
An attribute that specifies the information regarding the interface variable: descriptor set...
static TargetEnvAttr get(VerCapExtAttr triple, Vendor vendorID, DeviceType deviceType, uint32_t deviceId, DictionaryAttr limits)
Gets a TargetEnvAttr instance.
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:89
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
static Type parseSampledImageType(SPIRVDialect const &dialect, DialectAsmParser &parser)
bool isa() const
Definition: Attributes.h:107
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
unsigned getNumElements() const
Definition: SPIRVTypes.cpp:60
StringRef getEntryPointABIAttrName()
Returns the attribute name for specifying entry point information.
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
static Type parseMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
bool isIdentified() const
Returns true if the StructType is identified.
Definition: SPIRVTypes.cpp:987
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
Definition: LogicalResult.h:68
ArrayAttr getExtensionsAttr()
Returns the extensions as a string array attribute.
static RuntimeArrayType get(Type elementType)
Definition: SPIRVTypes.cpp:448
Type getElementType() const
Definition: SPIRVTypes.cpp:331
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
static bool isValid(FloatType)
Returns true if the given float type is valid for the SPIR-V dialect.
Definition: SPIRVTypes.cpp:491
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
static Optional< IntTy > parseAndVerifyInteger(SPIRVDialect const &dialect, DialectAsmParser &parser)
virtual ParseResult parseComma()=0
Parse a , token.
virtual llvm::SMLoc getNameLoc() const =0
Return the location of the original name token.
StringRef getTargetEnvAttrName()
Returns the attribute name for specifying SPIR-V target environment.
static constexpr const bool value
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:424
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:137
BlockArgument getArgument(unsigned i)
Definition: Region.h:124
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
Definition: SPIRVTypes.cpp:948
DeviceType getDeviceType() const
Returns the device type.
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:148
virtual ParseResult parseLSquare()=0
Parse a [ token.
static Attribute parseTargetEnvAttr(DialectAsmParser &parser)
Parses a spirv::TargetEnvAttr.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
ImageArrayedInfo getArrayedInfo() const
Definition: SPIRVTypes.cpp:337
static SampledImageType get(Type imageType)
Definition: SPIRVTypes.cpp:719
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
static Optional< ValTy > parseAndVerify(SPIRVDialect const &dialect, DialectAsmParser &parser)
static CooperativeMatrixNVType get(Type elementType, Scope scope, unsigned rows, unsigned columns)
Definition: SPIRVTypes.cpp:222
static Type parseAndVerifySampledImageType(SPIRVDialect const &dialect, DialectAsmParser &parser)
virtual ParseResult parseGreater()=0
Parse a &#39;>&#39; token.
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:48
static Attribute parseInterfaceVarABIAttr(DialectAsmParser &parser)
Parses a spirv::InterfaceVarABIAttr.
virtual llvm::SMLoc getCurrentLocation()=0
Return the location of the next token.
U dyn_cast() const
Definition: Types.h:244
ResourceLimitsAttr getResourceLimits() const
Returns the target resource limits.
virtual ParseResult parseOptionalRSquare()=0
Parse a ] token if present.
Attributes are known-constant values of operations.
Definition: Attributes.h:24
ImageSamplerUseInfo getSamplerUseInfo() const
Definition: SPIRVTypes.cpp:345
Type getElementType() const
Definition: SPIRVTypes.cpp:62
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:44
static Type parseArrayType(SPIRVDialect const &dialect, DialectAsmParser &parser)
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:32
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
virtual ParseResult parseRParen()=0
Parse a ) token.
This is the interface that must be implemented by the dialects of operations to be inlined...
Definition: InliningUtils.h:41
static Type parseStructType(SPIRVDialect const &dialect, DialectAsmParser &parser)
ParseResult parseKeyword(StringRef keyword, const Twine &msg="")
Parse a given keyword.
static StringRef getKindName()
Returns the attribute kind's name (without the 'spv.' prefix).
virtual ParseResult parseLess()=0
Parse a &#39;<&#39; token.
uint32_t getDescriptorSet()
Returns descriptor set.
ImageDepthInfo getDepthInfo() const
Definition: SPIRVTypes.cpp:335
StringRef getIdentifier() const
For literal structs, return an empty string.
Definition: SPIRVTypes.cpp:985
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:106
static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect, DialectAsmParser &parser, unsigned &stride)
Parses an optional , stride = N assembly segment.
Optional< unsigned > parseAndVerify< unsigned >(SPIRVDialect const &dialect, DialectAsmParser &parser)
uint64_t getMemberOffset(unsigned) const
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
Definition: SPIRVTypes.cpp:971
An attribute that specifies the SPIR-V (version, capabilities, extensions) triple.
Vendor getVendorID() const
Returns the vendor ID.
uint32_t getBinding()
Returns binding.
virtual ParseResult parseRSquare()=0
Parse a ] token.
static void print(ArrayType type, DialectAsmPrinter &os)
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:391
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
static constexpr uint32_t kUnknownDeviceID
ID for unknown devices.
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:52
uint32_t getDeviceID() const
Returns the device ID.
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:230
unsigned getNumElements() const
Definition: SPIRVTypes.cpp:989
virtual InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true)=0
Parse an 'x'-separated dimension list.
cap_range getCapabilities()
Returns the capabilities.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
virtual ParseResult parseOptionalGreater()=0
Parse a &#39;>&#39; token if present.
Type getElementType(unsigned) const
Definition: SPIRVTypes.cpp:991
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:458
Dim
Dimension level type for a tensor (undef means index does not appear).
Definition: Merger.h:24
Type parseType(DialectAsmParser &parser)
Parses an LLVM dialect type.
This class is a general helper class for creating context-global objects like types, attributes, and affine expressions.
Definition: Builders.h:49
static Type parseRuntimeArrayType(SPIRVDialect const &dialect, DialectAsmParser &parser)
Type getType() const
Return the type of this value.
Definition: Value.h:117
static Type parseAndVerifyType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseCooperativeMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
unsigned getRows() const
return the number of rows of the matrix.
Definition: SPIRVTypes.cpp:234
U dyn_cast() const
Definition: Attributes.h:117
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition: SPIRVTypes.h:163
static InterfaceVarABIAttr get(uint32_t descriptorSet, uint32_t binding, Optional< StorageClass > storageClass, MLIRContext *context)
Gets a InterfaceVarABIAttr.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual ParseResult parseType(Type &result)=0
Parse a type.
static LogicalResult verifyRegionAttribute(Location loc, Type valueType, NamedAttribute attribute)
Verifies the given SPIR-V attribute attached to a value of the given valueType is valid...
SPIR-V struct type.
Definition: SPIRVTypes.h:278
StorageClass getStorageClass() const
Definition: SPIRVTypes.cpp:397
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
VerCapExtAttr getTripleAttr() const
Returns the (version, capabilities, extensions) triple attribute.
static Type parsePointerType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
static ParseResult parseKeywordList(DialectAsmParser &parser, function_ref< LogicalResult(llvm::SMLoc, StringRef)> processKeyword)
Parses a comma-separated list of keywords, invokes processKeyword on each of the parsed keyword...
static StringRef getKindName()
Returns the attribute kind's name (without the 'spv.' prefix).
bool hasValue() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:62
virtual ParseResult parseEqual()=0
Parse a = token.
Scope getScope() const
Return the scope of the cooperative matrix.
Definition: SPIRVTypes.cpp:232
unsigned getNumColumns() const
Returns the number of columns.
static MatrixType get(Type columnType, uint32_t columnCount)
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
Definition: SPIRVTypes.cpp:961
static ParseResult parseEnumKeywordAttr(EnumClass &value, ParserType &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next keyword in parser as an enumerant of the given EnumClass.
Definition: ParserUtils.h:27
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
bool isa() const
Definition: Types.h:234
static ParseResult parseStructMemberDecorations(SPIRVDialect const &dialect, DialectAsmParser &parser, ArrayRef< Type > memberTypes, SmallVectorImpl< StructType::OffsetInfo > &offsetInfo, SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorationInfo)
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:231
This class represents success/failure for operation parsing.
Definition: OpDefinition.h:36
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This class helps build Operations.
Definition: Builders.h:177
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:205
Type getColumnType() const
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
ImageSamplingInfo getSamplingInfo() const
Definition: SPIRVTypes.cpp:341
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:201
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:429
static bool containsReturn(Region &region)
Returns true if the given region contains spv.Return or spv.ReturnValue ops.
An attribute that specifies the target version, allowed extensions and capabilities, and resource limits.
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
Optional< StorageClass > getStorageClass()
Returns spirv::StorageClass.
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type. ...
Definition: FoldUtils.cpp:50
unsigned getColumns() const
return the number of columns of the matrix.
Definition: SPIRVTypes.cpp:236
Version getVersion()
Returns the version.
bool isBF16() const
Definition: Types.cpp:21
static Value max(ImplicitLocOpBuilder &builder, Value a, Value b)
static Attribute parseVerCapExtAttr(DialectAsmParser &parser)