MLIR  20.0.0git
SPIRVDialect.cpp
Go to the documentation of this file.
1 //===- SPIRVDialect.cpp - MLIR SPIR-V dialect -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the SPIR-V dialect in MLIR.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
15 #include "SPIRVParsingUtils.h"
16 
22 #include "mlir/IR/Builders.h"
23 #include "mlir/IR/BuiltinTypes.h"
25 #include "mlir/IR/MLIRContext.h"
26 #include "mlir/Parser/Parser.h"
28 #include "llvm/ADT/DenseMap.h"
29 #include "llvm/ADT/Sequence.h"
30 #include "llvm/ADT/SetVector.h"
31 #include "llvm/ADT/StringExtras.h"
32 #include "llvm/ADT/StringMap.h"
33 #include "llvm/ADT/TypeSwitch.h"
34 #include "llvm/Support/raw_ostream.h"
35 
36 using namespace mlir;
37 using namespace mlir::spirv;
38 
39 #include "mlir/Dialect/SPIRV/IR/SPIRVOpsDialect.cpp.inc"
40 
41 //===----------------------------------------------------------------------===//
42 // InlinerInterface
43 //===----------------------------------------------------------------------===//
44 
45 /// Returns true if the given region contains spirv.Return or spirv.ReturnValue
46 /// ops.
47 static inline bool containsReturn(Region &region) {
48  return llvm::any_of(region, [](Block &block) {
49  Operation *terminator = block.getTerminator();
50  return isa<spirv::ReturnOp, spirv::ReturnValueOp>(terminator);
51  });
52 }
53 
54 namespace {
55 /// This class defines the interface for inlining within the SPIR-V dialect.
56 struct SPIRVInlinerInterface : public DialectInlinerInterface {
58 
59  /// All call operations within SPIRV can be inlined.
60  bool isLegalToInline(Operation *call, Operation *callable,
61  bool wouldBeCloned) const final {
62  return true;
63  }
64 
65  /// Returns true if the given region 'src' can be inlined into the region
66  /// 'dest' that is attached to an operation registered to the current dialect.
67  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
68  IRMapping &) const final {
69  // Return true here when inlining into spirv.func, spirv.mlir.selection, and
70  // spirv.mlir.loop operations.
71  auto *op = dest->getParentOp();
72  return isa<spirv::FuncOp, spirv::SelectionOp, spirv::LoopOp>(op);
73  }
74 
75  /// Returns true if the given operation 'op', that is registered to this
76  /// dialect, can be inlined into the region 'dest' that is attached to an
77  /// operation registered to the current dialect.
78  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
79  IRMapping &) const final {
80  // TODO: Enable inlining structured control flows with return.
81  if ((isa<spirv::SelectionOp, spirv::LoopOp>(op)) &&
82  containsReturn(op->getRegion(0)))
83  return false;
84  // TODO: we need to filter OpKill here to avoid inlining it to
85  // a loop continue construct:
86  // https://github.com/KhronosGroup/SPIRV-Headers/issues/86
87  // However OpKill is fragment shader specific and we don't support it yet.
88  return true;
89  }
90 
91  /// Handle the given inlined terminator by replacing it with a new operation
92  /// as necessary.
93  void handleTerminator(Operation *op, Block *newDest) const final {
94  if (auto returnOp = dyn_cast<spirv::ReturnOp>(op)) {
95  OpBuilder(op).create<spirv::BranchOp>(op->getLoc(), newDest);
96  op->erase();
97  } else if (auto retValOp = dyn_cast<spirv::ReturnValueOp>(op)) {
98  OpBuilder(op).create<spirv::BranchOp>(retValOp->getLoc(), newDest,
99  retValOp->getOperands());
100  op->erase();
101  }
102  }
103 
104  /// Handle the given inlined terminator by replacing it with a new operation
105  /// as necessary.
106  void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
107  // Only spirv.ReturnValue needs to be handled here.
108  auto retValOp = dyn_cast<spirv::ReturnValueOp>(op);
109  if (!retValOp)
110  return;
111 
112  // Replace the values directly with the return operands.
113  assert(valuesToRepl.size() == 1 &&
114  "spirv.ReturnValue expected to only handle one result");
115  valuesToRepl.front().replaceAllUsesWith(retValOp.getValue());
116  }
117 };
118 } // namespace
119 
120 //===----------------------------------------------------------------------===//
121 // SPIR-V Dialect
122 //===----------------------------------------------------------------------===//
123 
/// Dialect initialization hook. It performs these steps in order:
///   - registers the dialect's attributes and types;
///   - registers every op from the TableGen-generated op list;
///   - installs the SPIR-V inliner interface;
///   - allows unknown ops;
///   - declares gpu::TargetAttrInterface as a promised interface of
///     TargetEnvAttr.
void SPIRVDialect::initialize() {
  registerAttributes();
  registerTypes();

  // Add SPIR-V ops.
  // The op list is expanded from the generated .inc file via GET_OP_LIST.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"
      >();

  addInterfaces<SPIRVInlinerInterface>();

  // Allow unknown operations because SPIR-V is extensible.
  allowUnknownOperations();
  declarePromisedInterface<gpu::TargetAttrInterface, TargetEnvAttr>();
}
140 
141 std::string SPIRVDialect::getAttributeName(Decoration decoration) {
142  return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration));
143 }
144 
145 //===----------------------------------------------------------------------===//
146 // Type Parsing
147 //===----------------------------------------------------------------------===//
148 
// Forward declarations.
// Primary template: parses a keyword and maps it to the enum value of type
// `ValTy` (defined further below). Returns std::nullopt on failure.
template <typename ValTy>
static std::optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
                                           DialectAsmParser &parser);
// Specialization for a (verified) type operand.
template <>
std::optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
                                         DialectAsmParser &parser);

// Specialization for an unsigned integer literal.
template <>
std::optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
                                                 DialectAsmParser &parser);
160 
161 static Type parseAndVerifyType(SPIRVDialect const &dialect,
162  DialectAsmParser &parser) {
163  Type type;
164  SMLoc typeLoc = parser.getCurrentLocation();
165  if (parser.parseType(type))
166  return Type();
167 
168  // Allow SPIR-V dialect types
169  if (&type.getDialect() == &dialect)
170  return type;
171 
172  // Check other allowed types
173  if (auto t = llvm::dyn_cast<FloatType>(type)) {
174  if (type.isBF16()) {
175  parser.emitError(typeLoc, "cannot use 'bf16' to compose SPIR-V types");
176  return Type();
177  }
178  } else if (auto t = llvm::dyn_cast<IntegerType>(type)) {
179  if (!ScalarType::isValid(t)) {
180  parser.emitError(typeLoc,
181  "only 1/8/16/32/64-bit integer type allowed but found ")
182  << type;
183  return Type();
184  }
185  } else if (auto t = llvm::dyn_cast<VectorType>(type)) {
186  if (t.getRank() != 1) {
187  parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
188  return Type();
189  }
190  if (t.getNumElements() > 4) {
191  parser.emitError(
192  typeLoc, "vector length has to be less than or equal to 4 but found ")
193  << t.getNumElements();
194  return Type();
195  }
196  } else {
197  parser.emitError(typeLoc, "cannot use ")
198  << type << " to compose SPIR-V types";
199  return Type();
200  }
201 
202  return type;
203 }
204 
205 static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect,
206  DialectAsmParser &parser) {
207  Type type;
208  SMLoc typeLoc = parser.getCurrentLocation();
209  if (parser.parseType(type))
210  return Type();
211 
212  if (auto t = llvm::dyn_cast<VectorType>(type)) {
213  if (t.getRank() != 1) {
214  parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
215  return Type();
216  }
217  if (t.getNumElements() > 4 || t.getNumElements() < 2) {
218  parser.emitError(typeLoc,
219  "matrix columns size has to be less than or equal "
220  "to 4 and greater than or equal 2, but found ")
221  << t.getNumElements();
222  return Type();
223  }
224 
225  if (!llvm::isa<FloatType>(t.getElementType())) {
226  parser.emitError(typeLoc, "matrix columns' elements must be of "
227  "Float type, got ")
228  << t.getElementType();
229  return Type();
230  }
231  } else {
232  parser.emitError(typeLoc, "matrix must be composed using vector "
233  "type, got ")
234  << type;
235  return Type();
236  }
237 
238  return type;
239 }
240 
241 static Type parseAndVerifySampledImageType(SPIRVDialect const &dialect,
242  DialectAsmParser &parser) {
243  Type type;
244  SMLoc typeLoc = parser.getCurrentLocation();
245  if (parser.parseType(type))
246  return Type();
247 
248  if (!llvm::isa<ImageType>(type)) {
249  parser.emitError(typeLoc,
250  "sampled image must be composed using image type, got ")
251  << type;
252  return Type();
253  }
254 
255  return type;
256 }
257 
258 /// Parses an optional `, stride = N` assembly segment. If no parsing failure
259 /// occurs, writes `N` to `stride` if existing and writes 0 to `stride` if
260 /// missing.
261 static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect,
262  DialectAsmParser &parser,
263  unsigned &stride) {
264  if (failed(parser.parseOptionalComma())) {
265  stride = 0;
266  return success();
267  }
268 
269  if (parser.parseKeyword("stride") || parser.parseEqual())
270  return failure();
271 
272  SMLoc strideLoc = parser.getCurrentLocation();
273  std::optional<unsigned> optStride = parseAndVerify<unsigned>(dialect, parser);
274  if (!optStride)
275  return failure();
276 
277  if (!(stride = *optStride)) {
278  parser.emitError(strideLoc, "ArrayStride must be greater than zero");
279  return failure();
280  }
281  return success();
282 }
283 
284 // element-type ::= integer-type
285 // | floating-point-type
286 // | vector-type
287 // | spirv-type
288 //
289 // array-type ::= `!spirv.array` `<` integer-literal `x` element-type
290 // (`,` `stride` `=` integer-literal)? `>`
291 static Type parseArrayType(SPIRVDialect const &dialect,
292  DialectAsmParser &parser) {
293  if (parser.parseLess())
294  return Type();
295 
296  SmallVector<int64_t, 1> countDims;
297  SMLoc countLoc = parser.getCurrentLocation();
298  if (parser.parseDimensionList(countDims, /*allowDynamic=*/false))
299  return Type();
300  if (countDims.size() != 1) {
301  parser.emitError(countLoc,
302  "expected single integer for array element count");
303  return Type();
304  }
305 
306  // According to the SPIR-V spec:
307  // "Length is the number of elements in the array. It must be at least 1."
308  int64_t count = countDims[0];
309  if (count == 0) {
310  parser.emitError(countLoc, "expected array length greater than 0");
311  return Type();
312  }
313 
314  Type elementType = parseAndVerifyType(dialect, parser);
315  if (!elementType)
316  return Type();
317 
318  unsigned stride = 0;
319  if (failed(parseOptionalArrayStride(dialect, parser, stride)))
320  return Type();
321 
322  if (parser.parseGreater())
323  return Type();
324  return ArrayType::get(elementType, count, stride);
325 }
326 
327 // cooperative-matrix-type ::=
328 // `!spirv.coopmatrix` `<` rows `x` columns `x` element-type `,`
329 // scope `,` use `>`
330 static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
331  DialectAsmParser &parser) {
332  if (parser.parseLess())
333  return {};
334 
336  SMLoc countLoc = parser.getCurrentLocation();
337  if (parser.parseDimensionList(dims, /*allowDynamic=*/false))
338  return {};
339 
340  if (dims.size() != 2) {
341  parser.emitError(countLoc, "expected row and column count");
342  return {};
343  }
344 
345  auto elementTy = parseAndVerifyType(dialect, parser);
346  if (!elementTy)
347  return {};
348 
349  Scope scope;
350  if (parser.parseComma() ||
351  spirv::parseEnumKeywordAttr(scope, parser, "scope <id>"))
352  return {};
353 
354  CooperativeMatrixUseKHR use;
355  if (parser.parseComma() ||
356  spirv::parseEnumKeywordAttr(use, parser, "use <id>"))
357  return {};
358 
359  if (parser.parseGreater())
360  return {};
361 
362  return CooperativeMatrixType::get(elementTy, dims[0], dims[1], scope, use);
363 }
364 
365 // TODO: Reorder methods to be utilities first and parse*Type
366 // methods in alphabetical order
367 //
368 // storage-class ::= `UniformConstant`
369 // | `Uniform`
370 // | `Workgroup`
371 // | <and other storage classes...>
372 //
373 // pointer-type ::= `!spirv.ptr<` element-type `,` storage-class `>`
374 static Type parsePointerType(SPIRVDialect const &dialect,
375  DialectAsmParser &parser) {
376  if (parser.parseLess())
377  return Type();
378 
379  auto pointeeType = parseAndVerifyType(dialect, parser);
380  if (!pointeeType)
381  return Type();
382 
383  StringRef storageClassSpec;
384  SMLoc storageClassLoc = parser.getCurrentLocation();
385  if (parser.parseComma() || parser.parseKeyword(&storageClassSpec))
386  return Type();
387 
388  auto storageClass = symbolizeStorageClass(storageClassSpec);
389  if (!storageClass) {
390  parser.emitError(storageClassLoc, "unknown storage class: ")
391  << storageClassSpec;
392  return Type();
393  }
394  if (parser.parseGreater())
395  return Type();
396  return PointerType::get(pointeeType, *storageClass);
397 }
398 
399 // runtime-array-type ::= `!spirv.rtarray` `<` element-type
400 // (`,` `stride` `=` integer-literal)? `>`
401 static Type parseRuntimeArrayType(SPIRVDialect const &dialect,
402  DialectAsmParser &parser) {
403  if (parser.parseLess())
404  return Type();
405 
406  Type elementType = parseAndVerifyType(dialect, parser);
407  if (!elementType)
408  return Type();
409 
410  unsigned stride = 0;
411  if (failed(parseOptionalArrayStride(dialect, parser, stride)))
412  return Type();
413 
414  if (parser.parseGreater())
415  return Type();
416  return RuntimeArrayType::get(elementType, stride);
417 }
418 
419 // matrix-type ::= `!spirv.matrix` `<` integer-literal `x` element-type `>`
420 static Type parseMatrixType(SPIRVDialect const &dialect,
421  DialectAsmParser &parser) {
422  if (parser.parseLess())
423  return Type();
424 
425  SmallVector<int64_t, 1> countDims;
426  SMLoc countLoc = parser.getCurrentLocation();
427  if (parser.parseDimensionList(countDims, /*allowDynamic=*/false))
428  return Type();
429  if (countDims.size() != 1) {
430  parser.emitError(countLoc, "expected single unsigned "
431  "integer for number of columns");
432  return Type();
433  }
434 
435  int64_t columnCount = countDims[0];
436  // According to the specification, Matrices can have 2, 3, or 4 columns
437  if (columnCount < 2 || columnCount > 4) {
438  parser.emitError(countLoc, "matrix is expected to have 2, 3, or 4 "
439  "columns");
440  return Type();
441  }
442 
443  Type columnType = parseAndVerifyMatrixType(dialect, parser);
444  if (!columnType)
445  return Type();
446 
447  if (parser.parseGreater())
448  return Type();
449 
450  return MatrixType::get(columnType, columnCount);
451 }
452 
453 // Specialize this function to parse each of the parameters that define an
454 // ImageType. By default it assumes this is an enum type.
455 template <typename ValTy>
456 static std::optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
457  DialectAsmParser &parser) {
458  StringRef enumSpec;
459  SMLoc enumLoc = parser.getCurrentLocation();
460  if (parser.parseKeyword(&enumSpec)) {
461  return std::nullopt;
462  }
463 
464  auto val = spirv::symbolizeEnum<ValTy>(enumSpec);
465  if (!val)
466  parser.emitError(enumLoc, "unknown attribute: '") << enumSpec << "'";
467  return val;
468 }
469 
470 template <>
471 std::optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
472  DialectAsmParser &parser) {
473  // TODO: Further verify that the element type can be sampled
474  auto ty = parseAndVerifyType(dialect, parser);
475  if (!ty)
476  return std::nullopt;
477  return ty;
478 }
479 
480 template <typename IntTy>
481 static std::optional<IntTy> parseAndVerifyInteger(SPIRVDialect const &dialect,
482  DialectAsmParser &parser) {
483  IntTy offsetVal = std::numeric_limits<IntTy>::max();
484  if (parser.parseInteger(offsetVal))
485  return std::nullopt;
486  return offsetVal;
487 }
488 
489 template <>
490 std::optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
491  DialectAsmParser &parser) {
492  return parseAndVerifyInteger<unsigned>(dialect, parser);
493 }
494 
495 namespace {
496 // Functor object to parse a comma separated list of specs. The function
497 // parseAndVerify does the actual parsing and verification of individual
498 // elements. This is a functor since parsing the last element of the list
499 // (termination condition) needs partial specialization.
500 template <typename ParseType, typename... Args>
501 struct ParseCommaSeparatedList {
502  std::optional<std::tuple<ParseType, Args...>>
503  operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const {
504  auto parseVal = parseAndVerify<ParseType>(dialect, parser);
505  if (!parseVal)
506  return std::nullopt;
507 
508  auto numArgs = std::tuple_size<std::tuple<Args...>>::value;
509  if (numArgs != 0 && failed(parser.parseComma()))
510  return std::nullopt;
511  auto remainingValues = ParseCommaSeparatedList<Args...>{}(dialect, parser);
512  if (!remainingValues)
513  return std::nullopt;
514  return std::tuple_cat(std::tuple<ParseType>(parseVal.value()),
515  remainingValues.value());
516  }
517 };
518 
519 // Partial specialization of the function to parse a comma separated list of
520 // specs to parse the last element of the list.
521 template <typename ParseType>
522 struct ParseCommaSeparatedList<ParseType> {
523  std::optional<std::tuple<ParseType>>
524  operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const {
525  if (auto value = parseAndVerify<ParseType>(dialect, parser))
526  return std::tuple<ParseType>(*value);
527  return std::nullopt;
528  }
529 };
530 } // namespace
531 
532 // dim ::= `1D` | `2D` | `3D` | `Cube` | <and other SPIR-V Dim specifiers...>
533 //
534 // depth-info ::= `NoDepth` | `IsDepth` | `DepthUnknown`
535 //
536 // arrayed-info ::= `NonArrayed` | `Arrayed`
537 //
538 // sampling-info ::= `SingleSampled` | `MultiSampled`
539 //
540 // sampler-use-info ::= `SamplerUnknown` | `NeedSampler` | `NoSampler`
541 //
542 // format ::= `Unknown` | `Rgba32f` | <and other SPIR-V Image formats...>
543 //
544 // image-type ::= `!spirv.image<` element-type `,` dim `,` depth-info `,`
545 // arrayed-info `,` sampling-info `,`
546 // sampler-use-info `,` format `>`
547 static Type parseImageType(SPIRVDialect const &dialect,
548  DialectAsmParser &parser) {
549  if (parser.parseLess())
550  return Type();
551 
552  auto value =
553  ParseCommaSeparatedList<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
554  ImageSamplingInfo, ImageSamplerUseInfo,
555  ImageFormat>{}(dialect, parser);
556  if (!value)
557  return Type();
558 
559  if (parser.parseGreater())
560  return Type();
561  return ImageType::get(*value);
562 }
563 
564 // sampledImage-type :: = `!spirv.sampledImage<` image-type `>`
565 static Type parseSampledImageType(SPIRVDialect const &dialect,
566  DialectAsmParser &parser) {
567  if (parser.parseLess())
568  return Type();
569 
570  Type parsedType = parseAndVerifySampledImageType(dialect, parser);
571  if (!parsedType)
572  return Type();
573 
574  if (parser.parseGreater())
575  return Type();
576  return SampledImageType::get(parsedType);
577 }
578 
579 // Parse decorations associated with a member.
580 static ParseResult parseStructMemberDecorations(
581  SPIRVDialect const &dialect, DialectAsmParser &parser,
582  ArrayRef<Type> memberTypes,
585 
586  // Check if the first element is offset.
587  SMLoc offsetLoc = parser.getCurrentLocation();
588  StructType::OffsetInfo offset = 0;
589  OptionalParseResult offsetParseResult = parser.parseOptionalInteger(offset);
590  if (offsetParseResult.has_value()) {
591  if (failed(*offsetParseResult))
592  return failure();
593 
594  if (offsetInfo.size() != memberTypes.size() - 1) {
595  return parser.emitError(offsetLoc,
596  "offset specification must be given for "
597  "all members");
598  }
599  offsetInfo.push_back(offset);
600  }
601 
602  // Check for no spirv::Decorations.
603  if (succeeded(parser.parseOptionalRSquare()))
604  return success();
605 
606  // If there was an offset, make sure to parse the comma.
607  if (offsetParseResult.has_value() && parser.parseComma())
608  return failure();
609 
610  // Check for spirv::Decorations.
611  auto parseDecorations = [&]() {
612  auto memberDecoration = parseAndVerify<spirv::Decoration>(dialect, parser);
613  if (!memberDecoration)
614  return failure();
615 
616  // Parse member decoration value if it exists.
617  if (succeeded(parser.parseOptionalEqual())) {
618  auto memberDecorationValue =
619  parseAndVerifyInteger<uint32_t>(dialect, parser);
620 
621  if (!memberDecorationValue)
622  return failure();
623 
624  memberDecorationInfo.emplace_back(
625  static_cast<uint32_t>(memberTypes.size() - 1), 1,
626  memberDecoration.value(), memberDecorationValue.value());
627  } else {
628  memberDecorationInfo.emplace_back(
629  static_cast<uint32_t>(memberTypes.size() - 1), 0,
630  memberDecoration.value(), 0);
631  }
632  return success();
633  };
634  if (failed(parser.parseCommaSeparatedList(parseDecorations)) ||
635  failed(parser.parseRSquare()))
636  return failure();
637 
638  return success();
639 }
640 
641 // struct-member-decoration ::= integer-literal? spirv-decoration*
642 // struct-type ::=
643 // `!spirv.struct<` (id `,`)?
644 // `(`
645 // (spirv-type (`[` struct-member-decoration `]`)?)*
646 // `)>`
647 static Type parseStructType(SPIRVDialect const &dialect,
648  DialectAsmParser &parser) {
649  // TODO: This function is quite lengthy. Break it down into smaller chunks.
650 
651  if (parser.parseLess())
652  return Type();
653 
654  StringRef identifier;
655  FailureOr<DialectAsmParser::CyclicParseReset> cyclicParse;
656 
657  // Check if this is an identified struct type.
658  if (succeeded(parser.parseOptionalKeyword(&identifier))) {
659  // Check if this is a possible recursive reference.
660  auto structType =
661  StructType::getIdentified(dialect.getContext(), identifier);
662  cyclicParse = parser.tryStartCyclicParse(structType);
663  if (succeeded(parser.parseOptionalGreater())) {
664  if (succeeded(cyclicParse)) {
665  parser.emitError(
666  parser.getNameLoc(),
667  "recursive struct reference not nested in struct definition");
668 
669  return Type();
670  }
671 
672  return structType;
673  }
674 
675  if (failed(parser.parseComma()))
676  return Type();
677 
678  if (failed(cyclicParse)) {
679  parser.emitError(parser.getNameLoc(),
680  "identifier already used for an enclosing struct");
681  return Type();
682  }
683  }
684 
685  if (failed(parser.parseLParen()))
686  return Type();
687 
688  if (succeeded(parser.parseOptionalRParen()) &&
689  succeeded(parser.parseOptionalGreater())) {
690  return StructType::getEmpty(dialect.getContext(), identifier);
691  }
692 
693  StructType idStructTy;
694 
695  if (!identifier.empty())
696  idStructTy = StructType::getIdentified(dialect.getContext(), identifier);
697 
698  SmallVector<Type, 4> memberTypes;
701 
702  do {
703  Type memberType;
704  if (parser.parseType(memberType))
705  return Type();
706  memberTypes.push_back(memberType);
707 
708  if (succeeded(parser.parseOptionalLSquare()))
709  if (parseStructMemberDecorations(dialect, parser, memberTypes, offsetInfo,
710  memberDecorationInfo))
711  return Type();
712  } while (succeeded(parser.parseOptionalComma()));
713 
714  if (!offsetInfo.empty() && memberTypes.size() != offsetInfo.size()) {
715  parser.emitError(parser.getNameLoc(),
716  "offset specification must be given for all members");
717  return Type();
718  }
719 
720  if (failed(parser.parseRParen()) || failed(parser.parseGreater()))
721  return Type();
722 
723  if (!identifier.empty()) {
724  if (failed(idStructTy.trySetBody(memberTypes, offsetInfo,
725  memberDecorationInfo)))
726  return Type();
727  return idStructTy;
728  }
729 
730  return StructType::get(memberTypes, offsetInfo, memberDecorationInfo);
731 }
732 
733 // spirv-type ::= array-type
734 // | element-type
735 // | image-type
736 // | pointer-type
737 // | runtime-array-type
738 // | sampled-image-type
739 // | struct-type
741  StringRef keyword;
742  if (parser.parseKeyword(&keyword))
743  return Type();
744 
745  if (keyword == "array")
746  return parseArrayType(*this, parser);
747  if (keyword == "coopmatrix")
748  return parseCooperativeMatrixType(*this, parser);
749  if (keyword == "image")
750  return parseImageType(*this, parser);
751  if (keyword == "ptr")
752  return parsePointerType(*this, parser);
753  if (keyword == "rtarray")
754  return parseRuntimeArrayType(*this, parser);
755  if (keyword == "sampled_image")
756  return parseSampledImageType(*this, parser);
757  if (keyword == "struct")
758  return parseStructType(*this, parser);
759  if (keyword == "matrix")
760  return parseMatrixType(*this, parser);
761  parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword;
762  return Type();
763 }
764 
765 //===----------------------------------------------------------------------===//
766 // Type Printing
767 //===----------------------------------------------------------------------===//
768 
769 static void print(ArrayType type, DialectAsmPrinter &os) {
770  os << "array<" << type.getNumElements() << " x " << type.getElementType();
771  if (unsigned stride = type.getArrayStride())
772  os << ", stride=" << stride;
773  os << ">";
774 }
775 
776 static void print(RuntimeArrayType type, DialectAsmPrinter &os) {
777  os << "rtarray<" << type.getElementType();
778  if (unsigned stride = type.getArrayStride())
779  os << ", stride=" << stride;
780  os << ">";
781 }
782 
783 static void print(PointerType type, DialectAsmPrinter &os) {
784  os << "ptr<" << type.getPointeeType() << ", "
785  << stringifyStorageClass(type.getStorageClass()) << ">";
786 }
787 
788 static void print(ImageType type, DialectAsmPrinter &os) {
789  os << "image<" << type.getElementType() << ", " << stringifyDim(type.getDim())
790  << ", " << stringifyImageDepthInfo(type.getDepthInfo()) << ", "
791  << stringifyImageArrayedInfo(type.getArrayedInfo()) << ", "
792  << stringifyImageSamplingInfo(type.getSamplingInfo()) << ", "
793  << stringifyImageSamplerUseInfo(type.getSamplerUseInfo()) << ", "
794  << stringifyImageFormat(type.getImageFormat()) << ">";
795 }
796 
797 static void print(SampledImageType type, DialectAsmPrinter &os) {
798  os << "sampled_image<" << type.getImageType() << ">";
799 }
800 
801 static void print(StructType type, DialectAsmPrinter &os) {
802  FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint;
803 
804  os << "struct<";
805 
806  if (type.isIdentified()) {
807  os << type.getIdentifier();
808 
809  cyclicPrint = os.tryStartCyclicPrint(type);
810  if (failed(cyclicPrint)) {
811  os << ">";
812  return;
813  }
814 
815  os << ", ";
816  }
817 
818  os << "(";
819 
820  auto printMember = [&](unsigned i) {
821  os << type.getElementType(i);
823  type.getMemberDecorations(i, decorations);
824  if (type.hasOffset() || !decorations.empty()) {
825  os << " [";
826  if (type.hasOffset()) {
827  os << type.getMemberOffset(i);
828  if (!decorations.empty())
829  os << ", ";
830  }
831  auto eachFn = [&os](spirv::StructType::MemberDecorationInfo decoration) {
832  os << stringifyDecoration(decoration.decoration);
833  if (decoration.hasValue) {
834  os << "=" << decoration.decorationValue;
835  }
836  };
837  llvm::interleaveComma(decorations, os, eachFn);
838  os << "]";
839  }
840  };
841  llvm::interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os,
842  printMember);
843  os << ")>";
844 }
845 
847  os << "coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x"
848  << type.getElementType() << ", " << type.getScope() << ", "
849  << type.getUse() << ">";
850 }
851 
852 static void print(MatrixType type, DialectAsmPrinter &os) {
853  os << "matrix<" << type.getNumColumns() << " x " << type.getColumnType();
854  os << ">";
855 }
856 
857 void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
858  TypeSwitch<Type>(type)
861  [&](auto type) { print(type, os); })
862  .Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
863 }
864 
865 //===----------------------------------------------------------------------===//
866 // Constant
867 //===----------------------------------------------------------------------===//
868 
870  Attribute value, Type type,
871  Location loc) {
872  if (auto poison = dyn_cast<ub::PoisonAttr>(value))
873  return builder.create<ub::PoisonOp>(loc, type, poison);
874 
875  if (!spirv::ConstantOp::isBuildableWith(type))
876  return nullptr;
877 
878  return builder.create<spirv::ConstantOp>(loc, type, value);
879 }
880 
881 //===----------------------------------------------------------------------===//
882 // Shader Interface ABI
883 //===----------------------------------------------------------------------===//
884 
885 LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op,
886  NamedAttribute attribute) {
887  StringRef symbol = attribute.getName().strref();
888  Attribute attr = attribute.getValue();
889 
890  if (symbol == spirv::getEntryPointABIAttrName()) {
891  if (!llvm::isa<spirv::EntryPointABIAttr>(attr)) {
892  return op->emitError("'")
893  << symbol << "' attribute must be an entry point ABI attribute";
894  }
895  } else if (symbol == spirv::getTargetEnvAttrName()) {
896  if (!llvm::isa<spirv::TargetEnvAttr>(attr))
897  return op->emitError("'") << symbol << "' must be a spirv::TargetEnvAttr";
898  } else {
899  return op->emitError("found unsupported '")
900  << symbol << "' attribute on operation";
901  }
902 
903  return success();
904 }
905 
906 /// Verifies the given SPIR-V `attribute` attached to a value of the given
907 /// `valueType` is valid.
908 static LogicalResult verifyRegionAttribute(Location loc, Type valueType,
909  NamedAttribute attribute) {
910  StringRef symbol = attribute.getName().strref();
911  Attribute attr = attribute.getValue();
912 
913  if (symbol == spirv::getInterfaceVarABIAttrName()) {
914  auto varABIAttr = llvm::dyn_cast<spirv::InterfaceVarABIAttr>(attr);
915  if (!varABIAttr)
916  return emitError(loc, "'")
917  << symbol << "' must be a spirv::InterfaceVarABIAttr";
918 
919  if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat())
920  return emitError(loc, "'") << symbol
921  << "' attribute cannot specify storage class "
922  "when attaching to a non-scalar value";
923  return success();
924  }
925  if (symbol == spirv::DecorationAttr::name) {
926  if (!isa<spirv::DecorationAttr>(attr))
927  return emitError(loc, "'")
928  << symbol << "' must be a spirv::DecorationAttr";
929  return success();
930  }
931 
932  return emitError(loc, "found unsupported '")
933  << symbol << "' attribute on region argument";
934 }
935 
936 LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
937  unsigned regionIndex,
938  unsigned argIndex,
939  NamedAttribute attribute) {
940  auto funcOp = dyn_cast<FunctionOpInterface>(op);
941  if (!funcOp)
942  return success();
943  Type argType = funcOp.getArgumentTypes()[argIndex];
944 
945  return verifyRegionAttribute(op->getLoc(), argType, attribute);
946 }
947 
948 LogicalResult SPIRVDialect::verifyRegionResultAttribute(
949  Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/,
950  NamedAttribute attribute) {
951  return op->emitError("cannot attach SPIR-V attributes to region result");
952 }
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
std::optional< unsigned > parseAndVerify< unsigned >(SPIRVDialect const &dialect, DialectAsmParser &parser)
static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect, DialectAsmParser &parser, unsigned &stride)
Parses an optional `, stride = N` assembly segment.
static LogicalResult verifyRegionAttribute(Location loc, Type valueType, NamedAttribute attribute)
Verifies the given SPIR-V attribute attached to a value of the given valueType is valid.
static void print(ArrayType type, DialectAsmPrinter &os)
static Type parseSampledImageType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseAndVerifyType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static ParseResult parseStructMemberDecorations(SPIRVDialect const &dialect, DialectAsmParser &parser, ArrayRef< Type > memberTypes, SmallVectorImpl< StructType::OffsetInfo > &offsetInfo, SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorationInfo)
static Type parseAndVerifySampledImageType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseCooperativeMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
std::optional< Type > parseAndVerify< Type >(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseArrayType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static bool containsReturn(Region &region)
Returns true if the given region contains spirv.Return or spirv.ReturnValue ops.
static Type parseStructType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseRuntimeArrayType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static std::optional< IntTy > parseAndVerifyInteger(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseImageType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static std::optional< ValTy > parseAndVerify(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parsePointerType(SPIRVDialect const &dialect, DialectAsmParser &parser)
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
FailureOr< CyclicParseReset > tryStartCyclicParse(AttrOrTypeT attrOrType)
Attempts to start a cyclic parsing region for attrOrType.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalEqual()=0
Parse a `=` token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRParen()=0
Parse a `)` token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a `]` token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalRParen()=0
Parse a `)` token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseOptionalGreater()=0
Parse a '>' token if present.
virtual ParseResult parseEqual()=0
Parse a `=` token.
virtual SMLoc getCurrentLocation()=0
Return the location of the next token.
virtual ParseResult parseOptionalComma()=0
Parse a `,` token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalRSquare()=0
Parse a `]` token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a `(` token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a `,` token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a `[` token if present.
FailureOr< CyclicPrintReset > tryStartCyclicPrint(AttrOrTypeT attrOrType)
Attempts to start a cyclic printing region for attrOrType.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:49
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:221
This class helps build Operations.
Definition: Builders.h:216
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:123
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:131
bool isBF16() const
Definition: Types.cpp:56
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
Type getElementType() const
Definition: SPIRVTypes.cpp:66
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:68
unsigned getNumElements() const
Definition: SPIRVTypes.cpp:64
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:52
Scope getScope() const
Returns the scope of the matrix.
Definition: SPIRVTypes.cpp:240
uint32_t getRows() const
Returns the number of rows of the matrix.
Definition: SPIRVTypes.cpp:234
uint32_t getColumns() const
Returns the number of columns of the matrix.
Definition: SPIRVTypes.cpp:236
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
Definition: SPIRVTypes.cpp:222
CooperativeMatrixUseKHR getUse() const
Returns the use parameter of the cooperative matrix.
Definition: SPIRVTypes.cpp:242
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition: SPIRVTypes.h:168
ImageDepthInfo getDepthInfo() const
Definition: SPIRVTypes.cpp:349
ImageArrayedInfo getArrayedInfo() const
Definition: SPIRVTypes.cpp:351
ImageFormat getImageFormat() const
Definition: SPIRVTypes.cpp:363
ImageSamplerUseInfo getSamplerUseInfo() const
Definition: SPIRVTypes.cpp:359
Type getElementType() const
Definition: SPIRVTypes.cpp:345
ImageSamplingInfo getSamplingInfo() const
Definition: SPIRVTypes.cpp:355
static MatrixType get(Type columnType, uint32_t columnCount)
Type getColumnType() const
unsigned getNumColumns() const
Returns the number of columns.
Type getPointeeType() const
Definition: SPIRVTypes.cpp:410
StorageClass getStorageClass() const
Definition: SPIRVTypes.cpp:412
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:406
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:473
static RuntimeArrayType get(Type elementType)
Definition: SPIRVTypes.cpp:463
static SampledImageType get(Type imageType)
Definition: SPIRVTypes.cpp:732
static bool isValid(FloatType)
Returns true if the given float type is valid for the SPIR-V dialect.
Definition: SPIRVTypes.cpp:507
SPIR-V struct type.
Definition: SPIRVTypes.h:293
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
Definition: SPIRVTypes.cpp:974
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For identified structs, return the struct's identifier; for literal structs, return an empty string.
Definition: SPIRVTypes.cpp:998
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
Definition: SPIRVTypes.cpp:984
unsigned getNumElements() const
Type getElementType(unsigned) const
LogicalResult trySetBody(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Sets the contents of an incomplete identified StructType.
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
Definition: SPIRVTypes.cpp:961
uint64_t getMemberOffset(unsigned) const
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
ParseResult parseEnumKeywordAttr(EnumClass &value, ParserType &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next keyword in parser as an enumerant of the given EnumClass.
StringRef getTargetEnvAttrName()
Returns the attribute name for specifying SPIR-V target environment.
StringRef getEntryPointABIAttrName()
Returns the attribute name for specifying entry point information.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type to an MLIR context if it was valid.