MLIR  21.0.0git
SPIRVDialect.cpp
Go to the documentation of this file.
1 //===- SPIRVDialect.cpp - MLIR SPIR-V dialect -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the SPIR-V dialect in MLIR.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
15 #include "SPIRVParsingUtils.h"
16 
22 #include "mlir/IR/Builders.h"
23 #include "mlir/IR/BuiltinTypes.h"
25 #include "mlir/IR/MLIRContext.h"
26 #include "mlir/Parser/Parser.h"
28 #include "llvm/ADT/DenseMap.h"
29 #include "llvm/ADT/Sequence.h"
30 #include "llvm/ADT/SetVector.h"
31 #include "llvm/ADT/StringExtras.h"
32 #include "llvm/ADT/StringMap.h"
33 #include "llvm/ADT/TypeSwitch.h"
34 #include "llvm/Support/raw_ostream.h"
35 
36 using namespace mlir;
37 using namespace mlir::spirv;
38 
39 #include "mlir/Dialect/SPIRV/IR/SPIRVOpsDialect.cpp.inc"
40 
41 //===----------------------------------------------------------------------===//
42 // InlinerInterface
43 //===----------------------------------------------------------------------===//
44 
45 /// Returns true if the given region contains spirv.Return or spirv.ReturnValue
46 /// ops.
47 static inline bool containsReturn(Region &region) {
48  return llvm::any_of(region, [](Block &block) {
49  Operation *terminator = block.getTerminator();
50  return isa<spirv::ReturnOp, spirv::ReturnValueOp>(terminator);
51  });
52 }
53 
54 namespace {
55 /// This class defines the interface for inlining within the SPIR-V dialect.
56 struct SPIRVInlinerInterface : public DialectInlinerInterface {
58 
59  /// All call operations within SPIRV can be inlined.
60  bool isLegalToInline(Operation *call, Operation *callable,
61  bool wouldBeCloned) const final {
62  return true;
63  }
64 
65  /// Returns true if the given region 'src' can be inlined into the region
66  /// 'dest' that is attached to an operation registered to the current dialect.
67  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
68  IRMapping &) const final {
69  // Return true here when inlining into spirv.func, spirv.mlir.selection, and
70  // spirv.mlir.loop operations.
71  auto *op = dest->getParentOp();
72  return isa<spirv::FuncOp, spirv::SelectionOp, spirv::LoopOp>(op);
73  }
74 
75  /// Returns true if the given operation 'op', that is registered to this
76  /// dialect, can be inlined into the region 'dest' that is attached to an
77  /// operation registered to the current dialect.
78  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
79  IRMapping &) const final {
80  // TODO: Enable inlining structured control flows with return.
81  if ((isa<spirv::SelectionOp, spirv::LoopOp>(op)) &&
82  containsReturn(op->getRegion(0)))
83  return false;
84  // TODO: we need to filter OpKill here to avoid inlining it to
85  // a loop continue construct:
86  // https://github.com/KhronosGroup/SPIRV-Headers/issues/86
87  // For now, we just disallow inlining OpKill anywhere in the code,
88  // but this restriction should be relaxed, as pointed above.
89  if (isa<spirv::KillOp>(op))
90  return false;
91 
92  return true;
93  }
94 
95  /// Handle the given inlined terminator by replacing it with a new operation
96  /// as necessary.
97  void handleTerminator(Operation *op, Block *newDest) const final {
98  if (auto returnOp = dyn_cast<spirv::ReturnOp>(op)) {
99  OpBuilder(op).create<spirv::BranchOp>(op->getLoc(), newDest);
100  op->erase();
101  } else if (auto retValOp = dyn_cast<spirv::ReturnValueOp>(op)) {
102  OpBuilder(op).create<spirv::BranchOp>(retValOp->getLoc(), newDest,
103  retValOp->getOperands());
104  op->erase();
105  }
106  }
107 
108  /// Handle the given inlined terminator by replacing it with a new operation
109  /// as necessary.
110  void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
111  // Only spirv.ReturnValue needs to be handled here.
112  auto retValOp = dyn_cast<spirv::ReturnValueOp>(op);
113  if (!retValOp)
114  return;
115 
116  // Replace the values directly with the return operands.
117  assert(valuesToRepl.size() == 1 &&
118  "spirv.ReturnValue expected to only handle one result");
119  valuesToRepl.front().replaceAllUsesWith(retValOp.getValue());
120  }
121 };
122 } // namespace
123 
124 //===----------------------------------------------------------------------===//
125 // SPIR-V Dialect
126 //===----------------------------------------------------------------------===//
127 
128 void SPIRVDialect::initialize() {
129  registerAttributes();
130  registerTypes();
131 
132  // Add SPIR-V ops.
133  addOperations<
134 #define GET_OP_LIST
135 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"
136  >();
137 
138  addInterfaces<SPIRVInlinerInterface>();
139 
140  // Allow unknown operations because SPIR-V is extensible.
141  allowUnknownOperations();
142  declarePromisedInterface<gpu::TargetAttrInterface, TargetEnvAttr>();
143 }
144 
145 std::string SPIRVDialect::getAttributeName(Decoration decoration) {
146  return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration));
147 }
148 
149 //===----------------------------------------------------------------------===//
150 // Type Parsing
151 //===----------------------------------------------------------------------===//
152 
153 // Forward declarations.
154 template <typename ValTy>
155 static std::optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
156  DialectAsmParser &parser);
157 template <>
158 std::optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
159  DialectAsmParser &parser);
160 
161 template <>
162 std::optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
163  DialectAsmParser &parser);
164 
165 static Type parseAndVerifyType(SPIRVDialect const &dialect,
166  DialectAsmParser &parser) {
167  Type type;
168  SMLoc typeLoc = parser.getCurrentLocation();
169  if (parser.parseType(type))
170  return Type();
171 
172  // Allow SPIR-V dialect types
173  if (&type.getDialect() == &dialect)
174  return type;
175 
176  // Check other allowed types
177  if (auto t = llvm::dyn_cast<FloatType>(type)) {
178  // TODO: All float types are allowed for now, but this should be fixed.
179  } else if (auto t = llvm::dyn_cast<IntegerType>(type)) {
180  if (!ScalarType::isValid(t)) {
181  parser.emitError(typeLoc,
182  "only 1/8/16/32/64-bit integer type allowed but found ")
183  << type;
184  return Type();
185  }
186  } else if (auto t = llvm::dyn_cast<VectorType>(type)) {
187  if (t.getRank() != 1) {
188  parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
189  return Type();
190  }
191  if (t.getNumElements() > 4) {
192  parser.emitError(
193  typeLoc, "vector length has to be less than or equal to 4 but found ")
194  << t.getNumElements();
195  return Type();
196  }
197  } else {
198  parser.emitError(typeLoc, "cannot use ")
199  << type << " to compose SPIR-V types";
200  return Type();
201  }
202 
203  return type;
204 }
205 
206 static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect,
207  DialectAsmParser &parser) {
208  Type type;
209  SMLoc typeLoc = parser.getCurrentLocation();
210  if (parser.parseType(type))
211  return Type();
212 
213  if (auto t = llvm::dyn_cast<VectorType>(type)) {
214  if (t.getRank() != 1) {
215  parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
216  return Type();
217  }
218  if (t.getNumElements() > 4 || t.getNumElements() < 2) {
219  parser.emitError(typeLoc,
220  "matrix columns size has to be less than or equal "
221  "to 4 and greater than or equal 2, but found ")
222  << t.getNumElements();
223  return Type();
224  }
225 
226  if (!llvm::isa<FloatType>(t.getElementType())) {
227  parser.emitError(typeLoc, "matrix columns' elements must be of "
228  "Float type, got ")
229  << t.getElementType();
230  return Type();
231  }
232  } else {
233  parser.emitError(typeLoc, "matrix must be composed using vector "
234  "type, got ")
235  << type;
236  return Type();
237  }
238 
239  return type;
240 }
241 
242 static Type parseAndVerifySampledImageType(SPIRVDialect const &dialect,
243  DialectAsmParser &parser) {
244  Type type;
245  SMLoc typeLoc = parser.getCurrentLocation();
246  if (parser.parseType(type))
247  return Type();
248 
249  if (!llvm::isa<ImageType>(type)) {
250  parser.emitError(typeLoc,
251  "sampled image must be composed using image type, got ")
252  << type;
253  return Type();
254  }
255 
256  return type;
257 }
258 
259 /// Parses an optional `, stride = N` assembly segment. If no parsing failure
260 /// occurs, writes `N` to `stride` if existing and writes 0 to `stride` if
261 /// missing.
262 static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect,
263  DialectAsmParser &parser,
264  unsigned &stride) {
265  if (failed(parser.parseOptionalComma())) {
266  stride = 0;
267  return success();
268  }
269 
270  if (parser.parseKeyword("stride") || parser.parseEqual())
271  return failure();
272 
273  SMLoc strideLoc = parser.getCurrentLocation();
274  std::optional<unsigned> optStride = parseAndVerify<unsigned>(dialect, parser);
275  if (!optStride)
276  return failure();
277 
278  if (!(stride = *optStride)) {
279  parser.emitError(strideLoc, "ArrayStride must be greater than zero");
280  return failure();
281  }
282  return success();
283 }
284 
285 // element-type ::= integer-type
286 // | floating-point-type
287 // | vector-type
288 // | spirv-type
289 //
290 // array-type ::= `!spirv.array` `<` integer-literal `x` element-type
291 // (`,` `stride` `=` integer-literal)? `>`
292 static Type parseArrayType(SPIRVDialect const &dialect,
293  DialectAsmParser &parser) {
294  if (parser.parseLess())
295  return Type();
296 
297  SmallVector<int64_t, 1> countDims;
298  SMLoc countLoc = parser.getCurrentLocation();
299  if (parser.parseDimensionList(countDims, /*allowDynamic=*/false))
300  return Type();
301  if (countDims.size() != 1) {
302  parser.emitError(countLoc,
303  "expected single integer for array element count");
304  return Type();
305  }
306 
307  // According to the SPIR-V spec:
308  // "Length is the number of elements in the array. It must be at least 1."
309  int64_t count = countDims[0];
310  if (count == 0) {
311  parser.emitError(countLoc, "expected array length greater than 0");
312  return Type();
313  }
314 
315  Type elementType = parseAndVerifyType(dialect, parser);
316  if (!elementType)
317  return Type();
318 
319  unsigned stride = 0;
320  if (failed(parseOptionalArrayStride(dialect, parser, stride)))
321  return Type();
322 
323  if (parser.parseGreater())
324  return Type();
325  return ArrayType::get(elementType, count, stride);
326 }
327 
328 // cooperative-matrix-type ::=
329 // `!spirv.coopmatrix` `<` rows `x` columns `x` element-type `,`
330 // scope `,` use `>`
331 static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
332  DialectAsmParser &parser) {
333  if (parser.parseLess())
334  return {};
335 
337  SMLoc countLoc = parser.getCurrentLocation();
338  if (parser.parseDimensionList(dims, /*allowDynamic=*/false))
339  return {};
340 
341  if (dims.size() != 2) {
342  parser.emitError(countLoc, "expected row and column count");
343  return {};
344  }
345 
346  auto elementTy = parseAndVerifyType(dialect, parser);
347  if (!elementTy)
348  return {};
349 
350  Scope scope;
351  if (parser.parseComma() ||
352  spirv::parseEnumKeywordAttr(scope, parser, "scope <id>"))
353  return {};
354 
355  CooperativeMatrixUseKHR use;
356  if (parser.parseComma() ||
357  spirv::parseEnumKeywordAttr(use, parser, "use <id>"))
358  return {};
359 
360  if (parser.parseGreater())
361  return {};
362 
363  return CooperativeMatrixType::get(elementTy, dims[0], dims[1], scope, use);
364 }
365 
366 // TODO: Reorder methods to be utilities first and parse*Type
367 // methods in alphabetical order
368 //
369 // storage-class ::= `UniformConstant`
370 // | `Uniform`
371 // | `Workgroup`
372 // | <and other storage classes...>
373 //
374 // pointer-type ::= `!spirv.ptr<` element-type `,` storage-class `>`
375 static Type parsePointerType(SPIRVDialect const &dialect,
376  DialectAsmParser &parser) {
377  if (parser.parseLess())
378  return Type();
379 
380  auto pointeeType = parseAndVerifyType(dialect, parser);
381  if (!pointeeType)
382  return Type();
383 
384  StringRef storageClassSpec;
385  SMLoc storageClassLoc = parser.getCurrentLocation();
386  if (parser.parseComma() || parser.parseKeyword(&storageClassSpec))
387  return Type();
388 
389  auto storageClass = symbolizeStorageClass(storageClassSpec);
390  if (!storageClass) {
391  parser.emitError(storageClassLoc, "unknown storage class: ")
392  << storageClassSpec;
393  return Type();
394  }
395  if (parser.parseGreater())
396  return Type();
397  return PointerType::get(pointeeType, *storageClass);
398 }
399 
400 // runtime-array-type ::= `!spirv.rtarray` `<` element-type
401 // (`,` `stride` `=` integer-literal)? `>`
402 static Type parseRuntimeArrayType(SPIRVDialect const &dialect,
403  DialectAsmParser &parser) {
404  if (parser.parseLess())
405  return Type();
406 
407  Type elementType = parseAndVerifyType(dialect, parser);
408  if (!elementType)
409  return Type();
410 
411  unsigned stride = 0;
412  if (failed(parseOptionalArrayStride(dialect, parser, stride)))
413  return Type();
414 
415  if (parser.parseGreater())
416  return Type();
417  return RuntimeArrayType::get(elementType, stride);
418 }
419 
420 // matrix-type ::= `!spirv.matrix` `<` integer-literal `x` element-type `>`
421 static Type parseMatrixType(SPIRVDialect const &dialect,
422  DialectAsmParser &parser) {
423  if (parser.parseLess())
424  return Type();
425 
426  SmallVector<int64_t, 1> countDims;
427  SMLoc countLoc = parser.getCurrentLocation();
428  if (parser.parseDimensionList(countDims, /*allowDynamic=*/false))
429  return Type();
430  if (countDims.size() != 1) {
431  parser.emitError(countLoc, "expected single unsigned "
432  "integer for number of columns");
433  return Type();
434  }
435 
436  int64_t columnCount = countDims[0];
437  // According to the specification, Matrices can have 2, 3, or 4 columns
438  if (columnCount < 2 || columnCount > 4) {
439  parser.emitError(countLoc, "matrix is expected to have 2, 3, or 4 "
440  "columns");
441  return Type();
442  }
443 
444  Type columnType = parseAndVerifyMatrixType(dialect, parser);
445  if (!columnType)
446  return Type();
447 
448  if (parser.parseGreater())
449  return Type();
450 
451  return MatrixType::get(columnType, columnCount);
452 }
453 
454 // Specialize this function to parse each of the parameters that define an
455 // ImageType. By default it assumes this is an enum type.
456 template <typename ValTy>
457 static std::optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
458  DialectAsmParser &parser) {
459  StringRef enumSpec;
460  SMLoc enumLoc = parser.getCurrentLocation();
461  if (parser.parseKeyword(&enumSpec)) {
462  return std::nullopt;
463  }
464 
465  auto val = spirv::symbolizeEnum<ValTy>(enumSpec);
466  if (!val)
467  parser.emitError(enumLoc, "unknown attribute: '") << enumSpec << "'";
468  return val;
469 }
470 
471 template <>
472 std::optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
473  DialectAsmParser &parser) {
474  // TODO: Further verify that the element type can be sampled
475  auto ty = parseAndVerifyType(dialect, parser);
476  if (!ty)
477  return std::nullopt;
478  return ty;
479 }
480 
481 template <typename IntTy>
482 static std::optional<IntTy> parseAndVerifyInteger(SPIRVDialect const &dialect,
483  DialectAsmParser &parser) {
484  IntTy offsetVal = std::numeric_limits<IntTy>::max();
485  if (parser.parseInteger(offsetVal))
486  return std::nullopt;
487  return offsetVal;
488 }
489 
490 template <>
491 std::optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
492  DialectAsmParser &parser) {
493  return parseAndVerifyInteger<unsigned>(dialect, parser);
494 }
495 
496 namespace {
497 // Functor object to parse a comma separated list of specs. The function
498 // parseAndVerify does the actual parsing and verification of individual
499 // elements. This is a functor since parsing the last element of the list
500 // (termination condition) needs partial specialization.
501 template <typename ParseType, typename... Args>
502 struct ParseCommaSeparatedList {
503  std::optional<std::tuple<ParseType, Args...>>
504  operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const {
505  auto parseVal = parseAndVerify<ParseType>(dialect, parser);
506  if (!parseVal)
507  return std::nullopt;
508 
509  auto numArgs = std::tuple_size<std::tuple<Args...>>::value;
510  if (numArgs != 0 && failed(parser.parseComma()))
511  return std::nullopt;
512  auto remainingValues = ParseCommaSeparatedList<Args...>{}(dialect, parser);
513  if (!remainingValues)
514  return std::nullopt;
515  return std::tuple_cat(std::tuple<ParseType>(parseVal.value()),
516  remainingValues.value());
517  }
518 };
519 
520 // Partial specialization of the function to parse a comma separated list of
521 // specs to parse the last element of the list.
522 template <typename ParseType>
523 struct ParseCommaSeparatedList<ParseType> {
524  std::optional<std::tuple<ParseType>>
525  operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const {
526  if (auto value = parseAndVerify<ParseType>(dialect, parser))
527  return std::tuple<ParseType>(*value);
528  return std::nullopt;
529  }
530 };
531 } // namespace
532 
533 // dim ::= `1D` | `2D` | `3D` | `Cube` | <and other SPIR-V Dim specifiers...>
534 //
535 // depth-info ::= `NoDepth` | `IsDepth` | `DepthUnknown`
536 //
537 // arrayed-info ::= `NonArrayed` | `Arrayed`
538 //
539 // sampling-info ::= `SingleSampled` | `MultiSampled`
540 //
541 // sampler-use-info ::= `SamplerUnknown` | `NeedSampler` | `NoSampler`
542 //
543 // format ::= `Unknown` | `Rgba32f` | <and other SPIR-V Image formats...>
544 //
545 // image-type ::= `!spirv.image<` element-type `,` dim `,` depth-info `,`
546 // arrayed-info `,` sampling-info `,`
547 // sampler-use-info `,` format `>`
548 static Type parseImageType(SPIRVDialect const &dialect,
549  DialectAsmParser &parser) {
550  if (parser.parseLess())
551  return Type();
552 
553  auto value =
554  ParseCommaSeparatedList<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
555  ImageSamplingInfo, ImageSamplerUseInfo,
556  ImageFormat>{}(dialect, parser);
557  if (!value)
558  return Type();
559 
560  if (parser.parseGreater())
561  return Type();
562  return ImageType::get(*value);
563 }
564 
565 // sampledImage-type :: = `!spirv.sampledImage<` image-type `>`
566 static Type parseSampledImageType(SPIRVDialect const &dialect,
567  DialectAsmParser &parser) {
568  if (parser.parseLess())
569  return Type();
570 
571  Type parsedType = parseAndVerifySampledImageType(dialect, parser);
572  if (!parsedType)
573  return Type();
574 
575  if (parser.parseGreater())
576  return Type();
577  return SampledImageType::get(parsedType);
578 }
579 
580 // Parse decorations associated with a member.
581 static ParseResult parseStructMemberDecorations(
582  SPIRVDialect const &dialect, DialectAsmParser &parser,
583  ArrayRef<Type> memberTypes,
586 
587  // Check if the first element is offset.
588  SMLoc offsetLoc = parser.getCurrentLocation();
589  StructType::OffsetInfo offset = 0;
590  OptionalParseResult offsetParseResult = parser.parseOptionalInteger(offset);
591  if (offsetParseResult.has_value()) {
592  if (failed(*offsetParseResult))
593  return failure();
594 
595  if (offsetInfo.size() != memberTypes.size() - 1) {
596  return parser.emitError(offsetLoc,
597  "offset specification must be given for "
598  "all members");
599  }
600  offsetInfo.push_back(offset);
601  }
602 
603  // Check for no spirv::Decorations.
604  if (succeeded(parser.parseOptionalRSquare()))
605  return success();
606 
607  // If there was an offset, make sure to parse the comma.
608  if (offsetParseResult.has_value() && parser.parseComma())
609  return failure();
610 
611  // Check for spirv::Decorations.
612  auto parseDecorations = [&]() {
613  auto memberDecoration = parseAndVerify<spirv::Decoration>(dialect, parser);
614  if (!memberDecoration)
615  return failure();
616 
617  // Parse member decoration value if it exists.
618  if (succeeded(parser.parseOptionalEqual())) {
619  auto memberDecorationValue =
620  parseAndVerifyInteger<uint32_t>(dialect, parser);
621 
622  if (!memberDecorationValue)
623  return failure();
624 
625  memberDecorationInfo.emplace_back(
626  static_cast<uint32_t>(memberTypes.size() - 1), 1,
627  memberDecoration.value(), memberDecorationValue.value());
628  } else {
629  memberDecorationInfo.emplace_back(
630  static_cast<uint32_t>(memberTypes.size() - 1), 0,
631  memberDecoration.value(), 0);
632  }
633  return success();
634  };
635  if (failed(parser.parseCommaSeparatedList(parseDecorations)) ||
636  failed(parser.parseRSquare()))
637  return failure();
638 
639  return success();
640 }
641 
642 // struct-member-decoration ::= integer-literal? spirv-decoration*
643 // struct-type ::=
644 // `!spirv.struct<` (id `,`)?
645 // `(`
646 // (spirv-type (`[` struct-member-decoration `]`)?)*
647 // `)>`
648 static Type parseStructType(SPIRVDialect const &dialect,
649  DialectAsmParser &parser) {
650  // TODO: This function is quite lengthy. Break it down into smaller chunks.
651 
652  if (parser.parseLess())
653  return Type();
654 
655  StringRef identifier;
656  FailureOr<DialectAsmParser::CyclicParseReset> cyclicParse;
657 
658  // Check if this is an identified struct type.
659  if (succeeded(parser.parseOptionalKeyword(&identifier))) {
660  // Check if this is a possible recursive reference.
661  auto structType =
662  StructType::getIdentified(dialect.getContext(), identifier);
663  cyclicParse = parser.tryStartCyclicParse(structType);
664  if (succeeded(parser.parseOptionalGreater())) {
665  if (succeeded(cyclicParse)) {
666  parser.emitError(
667  parser.getNameLoc(),
668  "recursive struct reference not nested in struct definition");
669 
670  return Type();
671  }
672 
673  return structType;
674  }
675 
676  if (failed(parser.parseComma()))
677  return Type();
678 
679  if (failed(cyclicParse)) {
680  parser.emitError(parser.getNameLoc(),
681  "identifier already used for an enclosing struct");
682  return Type();
683  }
684  }
685 
686  if (failed(parser.parseLParen()))
687  return Type();
688 
689  if (succeeded(parser.parseOptionalRParen()) &&
690  succeeded(parser.parseOptionalGreater())) {
691  return StructType::getEmpty(dialect.getContext(), identifier);
692  }
693 
694  StructType idStructTy;
695 
696  if (!identifier.empty())
697  idStructTy = StructType::getIdentified(dialect.getContext(), identifier);
698 
699  SmallVector<Type, 4> memberTypes;
702 
703  do {
704  Type memberType;
705  if (parser.parseType(memberType))
706  return Type();
707  memberTypes.push_back(memberType);
708 
709  if (succeeded(parser.parseOptionalLSquare()))
710  if (parseStructMemberDecorations(dialect, parser, memberTypes, offsetInfo,
711  memberDecorationInfo))
712  return Type();
713  } while (succeeded(parser.parseOptionalComma()));
714 
715  if (!offsetInfo.empty() && memberTypes.size() != offsetInfo.size()) {
716  parser.emitError(parser.getNameLoc(),
717  "offset specification must be given for all members");
718  return Type();
719  }
720 
721  if (failed(parser.parseRParen()) || failed(parser.parseGreater()))
722  return Type();
723 
724  if (!identifier.empty()) {
725  if (failed(idStructTy.trySetBody(memberTypes, offsetInfo,
726  memberDecorationInfo)))
727  return Type();
728  return idStructTy;
729  }
730 
731  return StructType::get(memberTypes, offsetInfo, memberDecorationInfo);
732 }
733 
734 // spirv-type ::= array-type
735 // | element-type
736 // | image-type
737 // | pointer-type
738 // | runtime-array-type
739 // | sampled-image-type
740 // | struct-type
742  StringRef keyword;
743  if (parser.parseKeyword(&keyword))
744  return Type();
745 
746  if (keyword == "array")
747  return parseArrayType(*this, parser);
748  if (keyword == "coopmatrix")
749  return parseCooperativeMatrixType(*this, parser);
750  if (keyword == "image")
751  return parseImageType(*this, parser);
752  if (keyword == "ptr")
753  return parsePointerType(*this, parser);
754  if (keyword == "rtarray")
755  return parseRuntimeArrayType(*this, parser);
756  if (keyword == "sampled_image")
757  return parseSampledImageType(*this, parser);
758  if (keyword == "struct")
759  return parseStructType(*this, parser);
760  if (keyword == "matrix")
761  return parseMatrixType(*this, parser);
762  parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword;
763  return Type();
764 }
765 
766 //===----------------------------------------------------------------------===//
767 // Type Printing
768 //===----------------------------------------------------------------------===//
769 
770 static void print(ArrayType type, DialectAsmPrinter &os) {
771  os << "array<" << type.getNumElements() << " x " << type.getElementType();
772  if (unsigned stride = type.getArrayStride())
773  os << ", stride=" << stride;
774  os << ">";
775 }
776 
777 static void print(RuntimeArrayType type, DialectAsmPrinter &os) {
778  os << "rtarray<" << type.getElementType();
779  if (unsigned stride = type.getArrayStride())
780  os << ", stride=" << stride;
781  os << ">";
782 }
783 
784 static void print(PointerType type, DialectAsmPrinter &os) {
785  os << "ptr<" << type.getPointeeType() << ", "
786  << stringifyStorageClass(type.getStorageClass()) << ">";
787 }
788 
789 static void print(ImageType type, DialectAsmPrinter &os) {
790  os << "image<" << type.getElementType() << ", " << stringifyDim(type.getDim())
791  << ", " << stringifyImageDepthInfo(type.getDepthInfo()) << ", "
792  << stringifyImageArrayedInfo(type.getArrayedInfo()) << ", "
793  << stringifyImageSamplingInfo(type.getSamplingInfo()) << ", "
794  << stringifyImageSamplerUseInfo(type.getSamplerUseInfo()) << ", "
795  << stringifyImageFormat(type.getImageFormat()) << ">";
796 }
797 
798 static void print(SampledImageType type, DialectAsmPrinter &os) {
799  os << "sampled_image<" << type.getImageType() << ">";
800 }
801 
802 static void print(StructType type, DialectAsmPrinter &os) {
803  FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint;
804 
805  os << "struct<";
806 
807  if (type.isIdentified()) {
808  os << type.getIdentifier();
809 
810  cyclicPrint = os.tryStartCyclicPrint(type);
811  if (failed(cyclicPrint)) {
812  os << ">";
813  return;
814  }
815 
816  os << ", ";
817  }
818 
819  os << "(";
820 
821  auto printMember = [&](unsigned i) {
822  os << type.getElementType(i);
824  type.getMemberDecorations(i, decorations);
825  if (type.hasOffset() || !decorations.empty()) {
826  os << " [";
827  if (type.hasOffset()) {
828  os << type.getMemberOffset(i);
829  if (!decorations.empty())
830  os << ", ";
831  }
832  auto eachFn = [&os](spirv::StructType::MemberDecorationInfo decoration) {
833  os << stringifyDecoration(decoration.decoration);
834  if (decoration.hasValue) {
835  os << "=" << decoration.decorationValue;
836  }
837  };
838  llvm::interleaveComma(decorations, os, eachFn);
839  os << "]";
840  }
841  };
842  llvm::interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os,
843  printMember);
844  os << ")>";
845 }
846 
848  os << "coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x"
849  << type.getElementType() << ", " << type.getScope() << ", "
850  << type.getUse() << ">";
851 }
852 
853 static void print(MatrixType type, DialectAsmPrinter &os) {
854  os << "matrix<" << type.getNumColumns() << " x " << type.getColumnType();
855  os << ">";
856 }
857 
858 void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
859  TypeSwitch<Type>(type)
862  [&](auto type) { print(type, os); })
863  .Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
864 }
865 
866 //===----------------------------------------------------------------------===//
867 // Constant
868 //===----------------------------------------------------------------------===//
869 
871  Attribute value, Type type,
872  Location loc) {
873  if (auto poison = dyn_cast<ub::PoisonAttr>(value))
874  return builder.create<ub::PoisonOp>(loc, type, poison);
875 
876  if (!spirv::ConstantOp::isBuildableWith(type))
877  return nullptr;
878 
879  return builder.create<spirv::ConstantOp>(loc, type, value);
880 }
881 
882 //===----------------------------------------------------------------------===//
883 // Shader Interface ABI
884 //===----------------------------------------------------------------------===//
885 
886 LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op,
887  NamedAttribute attribute) {
888  StringRef symbol = attribute.getName().strref();
889  Attribute attr = attribute.getValue();
890 
891  if (symbol == spirv::getEntryPointABIAttrName()) {
892  if (!llvm::isa<spirv::EntryPointABIAttr>(attr)) {
893  return op->emitError("'")
894  << symbol << "' attribute must be an entry point ABI attribute";
895  }
896  } else if (symbol == spirv::getTargetEnvAttrName()) {
897  if (!llvm::isa<spirv::TargetEnvAttr>(attr))
898  return op->emitError("'") << symbol << "' must be a spirv::TargetEnvAttr";
899  } else {
900  return op->emitError("found unsupported '")
901  << symbol << "' attribute on operation";
902  }
903 
904  return success();
905 }
906 
907 /// Verifies the given SPIR-V `attribute` attached to a value of the given
908 /// `valueType` is valid.
909 static LogicalResult verifyRegionAttribute(Location loc, Type valueType,
910  NamedAttribute attribute) {
911  StringRef symbol = attribute.getName().strref();
912  Attribute attr = attribute.getValue();
913 
914  if (symbol == spirv::getInterfaceVarABIAttrName()) {
915  auto varABIAttr = llvm::dyn_cast<spirv::InterfaceVarABIAttr>(attr);
916  if (!varABIAttr)
917  return emitError(loc, "'")
918  << symbol << "' must be a spirv::InterfaceVarABIAttr";
919 
920  if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat())
921  return emitError(loc, "'") << symbol
922  << "' attribute cannot specify storage class "
923  "when attaching to a non-scalar value";
924  return success();
925  }
926  if (symbol == spirv::DecorationAttr::name) {
927  if (!isa<spirv::DecorationAttr>(attr))
928  return emitError(loc, "'")
929  << symbol << "' must be a spirv::DecorationAttr";
930  return success();
931  }
932 
933  return emitError(loc, "found unsupported '")
934  << symbol << "' attribute on region argument";
935 }
936 
937 LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
938  unsigned regionIndex,
939  unsigned argIndex,
940  NamedAttribute attribute) {
941  auto funcOp = dyn_cast<FunctionOpInterface>(op);
942  if (!funcOp)
943  return success();
944  Type argType = funcOp.getArgumentTypes()[argIndex];
945 
946  return verifyRegionAttribute(op->getLoc(), argType, attribute);
947 }
948 
949 LogicalResult SPIRVDialect::verifyRegionResultAttribute(
950  Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/,
951  NamedAttribute attribute) {
952  return op->emitError("cannot attach SPIR-V attributes to region result");
953 }
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
std::optional< unsigned > parseAndVerify< unsigned >(SPIRVDialect const &dialect, DialectAsmParser &parser)
static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect, DialectAsmParser &parser, unsigned &stride)
Parses an optional `, stride = N` assembly segment.
static LogicalResult verifyRegionAttribute(Location loc, Type valueType, NamedAttribute attribute)
Verifies the given SPIR-V attribute attached to a value of the given valueType is valid.
static void print(ArrayType type, DialectAsmPrinter &os)
static Type parseSampledImageType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseAndVerifyType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static ParseResult parseStructMemberDecorations(SPIRVDialect const &dialect, DialectAsmParser &parser, ArrayRef< Type > memberTypes, SmallVectorImpl< StructType::OffsetInfo > &offsetInfo, SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorationInfo)
static Type parseAndVerifySampledImageType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseCooperativeMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
std::optional< Type > parseAndVerify< Type >(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseArrayType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static bool containsReturn(Region &region)
Returns true if the given region contains spirv.Return or spirv.ReturnValue ops.
static Type parseStructType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseRuntimeArrayType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static std::optional< IntTy > parseAndVerifyInteger(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseImageType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static std::optional< ValTy > parseAndVerify(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parsePointerType(SPIRVDialect const &dialect, DialectAsmParser &parser)
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
FailureOr< CyclicParseReset > tryStartCyclicParse(AttrOrTypeT attrOrType)
Attempts to start a cyclic parsing region for attrOrType.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseOptionalGreater()=0
Parse a '>' token if present.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Returns the location of the next token.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalRSquare()=0
Parse a ] token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
FailureOr< CyclicPrintReset > tryStartCyclicPrint(AttrOrTypeT attrOrType)
Attempts to start a cyclic printing region for attrOrType.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:55
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:179
This class helps build Operations.
Definition: Builders.h:205
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:455
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:107
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:120
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
Type getElementType() const
Definition: SPIRVTypes.cpp:66
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:68
unsigned getNumElements() const
Definition: SPIRVTypes.cpp:64
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:52
Scope getScope() const
Returns the scope of the matrix.
Definition: SPIRVTypes.cpp:261
uint32_t getRows() const
Returns the number of rows of the matrix.
Definition: SPIRVTypes.cpp:247
uint32_t getColumns() const
Returns the number of columns of the matrix.
Definition: SPIRVTypes.cpp:252
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
Definition: SPIRVTypes.cpp:235
CooperativeMatrixUseKHR getUse() const
Returns the use parameter of the cooperative matrix.
Definition: SPIRVTypes.cpp:263
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition: SPIRVTypes.h:168
ImageDepthInfo getDepthInfo() const
Definition: SPIRVTypes.cpp:370
ImageArrayedInfo getArrayedInfo() const
Definition: SPIRVTypes.cpp:372
ImageFormat getImageFormat() const
Definition: SPIRVTypes.cpp:384
ImageSamplerUseInfo getSamplerUseInfo() const
Definition: SPIRVTypes.cpp:380
Type getElementType() const
Definition: SPIRVTypes.cpp:366
ImageSamplingInfo getSamplingInfo() const
Definition: SPIRVTypes.cpp:376
static MatrixType get(Type columnType, uint32_t columnCount)
Type getColumnType() const
unsigned getNumColumns() const
Returns the number of columns.
Type getPointeeType() const
Definition: SPIRVTypes.cpp:431
StorageClass getStorageClass() const
Definition: SPIRVTypes.cpp:433
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:427
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:494
static RuntimeArrayType get(Type elementType)
Definition: SPIRVTypes.cpp:484
static SampledImageType get(Type imageType)
Definition: SPIRVTypes.cpp:767
static bool isValid(FloatType)
Returns true if the given float type is valid for the SPIR-V dialect.
Definition: SPIRVTypes.cpp:528
SPIR-V struct type.
Definition: SPIRVTypes.h:293
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For identified structs, returns the struct's identifier; for literal structs, returns an empty string.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
unsigned getNumElements() const
Type getElementType(unsigned) const
LogicalResult trySetBody(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Sets the contents of an incomplete identified StructType.
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
Definition: SPIRVTypes.cpp:996
uint64_t getMemberOffset(unsigned) const
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
ParseResult parseEnumKeywordAttr(EnumClass &value, ParserType &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next keyword in parser as an enumerant of the given EnumClass.
StringRef getTargetEnvAttrName()
Returns the attribute name for specifying SPIR-V target environment.
StringRef getEntryPointABIAttrName()
Returns the attribute name for specifying entry point information.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type to an MLIR context if it was valid.