MLIR  19.0.0git
SPIRVDialect.cpp
Go to the documentation of this file.
//===- SPIRVDialect.cpp - MLIR SPIR-V dialect -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the SPIR-V dialect in MLIR.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
15 #include "SPIRVParsingUtils.h"
16 
22 #include "mlir/IR/Builders.h"
23 #include "mlir/IR/BuiltinTypes.h"
25 #include "mlir/IR/MLIRContext.h"
26 #include "mlir/Parser/Parser.h"
28 #include "llvm/ADT/DenseMap.h"
29 #include "llvm/ADT/Sequence.h"
30 #include "llvm/ADT/SetVector.h"
31 #include "llvm/ADT/StringExtras.h"
32 #include "llvm/ADT/StringMap.h"
33 #include "llvm/ADT/TypeSwitch.h"
34 #include "llvm/Support/raw_ostream.h"
35 
36 using namespace mlir;
37 using namespace mlir::spirv;
38 
39 #include "mlir/Dialect/SPIRV/IR/SPIRVOpsDialect.cpp.inc"
40 
41 //===----------------------------------------------------------------------===//
42 // InlinerInterface
43 //===----------------------------------------------------------------------===//
44 
45 /// Returns true if the given region contains spirv.Return or spirv.ReturnValue
46 /// ops.
47 static inline bool containsReturn(Region &region) {
48  return llvm::any_of(region, [](Block &block) {
49  Operation *terminator = block.getTerminator();
50  return isa<spirv::ReturnOp, spirv::ReturnValueOp>(terminator);
51  });
52 }
53 
54 namespace {
55 /// This class defines the interface for inlining within the SPIR-V dialect.
56 struct SPIRVInlinerInterface : public DialectInlinerInterface {
58 
59  /// All call operations within SPIRV can be inlined.
60  bool isLegalToInline(Operation *call, Operation *callable,
61  bool wouldBeCloned) const final {
62  return true;
63  }
64 
65  /// Returns true if the given region 'src' can be inlined into the region
66  /// 'dest' that is attached to an operation registered to the current dialect.
67  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
68  IRMapping &) const final {
69  // Return true here when inlining into spirv.func, spirv.mlir.selection, and
70  // spirv.mlir.loop operations.
71  auto *op = dest->getParentOp();
72  return isa<spirv::FuncOp, spirv::SelectionOp, spirv::LoopOp>(op);
73  }
74 
75  /// Returns true if the given operation 'op', that is registered to this
76  /// dialect, can be inlined into the region 'dest' that is attached to an
77  /// operation registered to the current dialect.
78  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
79  IRMapping &) const final {
80  // TODO: Enable inlining structured control flows with return.
81  if ((isa<spirv::SelectionOp, spirv::LoopOp>(op)) &&
82  containsReturn(op->getRegion(0)))
83  return false;
84  // TODO: we need to filter OpKill here to avoid inlining it to
85  // a loop continue construct:
86  // https://github.com/KhronosGroup/SPIRV-Headers/issues/86
87  // However OpKill is fragment shader specific and we don't support it yet.
88  return true;
89  }
90 
91  /// Handle the given inlined terminator by replacing it with a new operation
92  /// as necessary.
93  void handleTerminator(Operation *op, Block *newDest) const final {
94  if (auto returnOp = dyn_cast<spirv::ReturnOp>(op)) {
95  OpBuilder(op).create<spirv::BranchOp>(op->getLoc(), newDest);
96  op->erase();
97  } else if (auto retValOp = dyn_cast<spirv::ReturnValueOp>(op)) {
98  OpBuilder(op).create<spirv::BranchOp>(retValOp->getLoc(), newDest,
99  retValOp->getOperands());
100  op->erase();
101  }
102  }
103 
104  /// Handle the given inlined terminator by replacing it with a new operation
105  /// as necessary.
106  void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
107  // Only spirv.ReturnValue needs to be handled here.
108  auto retValOp = dyn_cast<spirv::ReturnValueOp>(op);
109  if (!retValOp)
110  return;
111 
112  // Replace the values directly with the return operands.
113  assert(valuesToRepl.size() == 1 &&
114  "spirv.ReturnValue expected to only handle one result");
115  valuesToRepl.front().replaceAllUsesWith(retValOp.getValue());
116  }
117 };
118 } // namespace
119 
120 //===----------------------------------------------------------------------===//
121 // SPIR-V Dialect
122 //===----------------------------------------------------------------------===//
123 
/// Dialect initialization hook. Registers the SPIR-V attributes, types and
/// operations, attaches the inliner interface, and configures the dialect to
/// accept unregistered operations.
void SPIRVDialect::initialize() {
  registerAttributes();
  registerTypes();

  // Add SPIR-V ops.
  // The op class list is generated by TableGen. Defining GET_OP_LIST selects
  // the portion of the .inc file that (presumably) expands to the
  // comma-separated list of op classes used as template arguments here.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"
      >();

  // Attach the SPIR-V inliner interface defined earlier in this file.
  addInterfaces<SPIRVInlinerInterface>();

  // Allow unknown operations because SPIR-V is extensible.
  allowUnknownOperations();
  // Declares that TargetEnvAttr will provide gpu::TargetAttrInterface. The
  // implementation is not attached in this function.
  // NOTE(review): presumably supplied by a GPU-side extension; confirm.
  declarePromisedInterface<gpu::TargetAttrInterface, TargetEnvAttr>();
}
140 
141 std::string SPIRVDialect::getAttributeName(Decoration decoration) {
142  return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration));
143 }
144 
145 //===----------------------------------------------------------------------===//
146 // Type Parsing
147 //===----------------------------------------------------------------------===//
148 
// Forward declarations.
//
// parseOptionalArrayStride (below) calls parseAndVerify<unsigned> before the
// template and its specializations are defined, so they are declared here.
// The primary template parses a keyword-spelled enum value; the explicit
// specializations handle `Type` and `unsigned` parameters.
template <typename ValTy>
static std::optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
                                           DialectAsmParser &parser);
template <>
std::optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
                                         DialectAsmParser &parser);

template <>
std::optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
                                                 DialectAsmParser &parser);
160 
161 static Type parseAndVerifyType(SPIRVDialect const &dialect,
162  DialectAsmParser &parser) {
163  Type type;
164  SMLoc typeLoc = parser.getCurrentLocation();
165  if (parser.parseType(type))
166  return Type();
167 
168  // Allow SPIR-V dialect types
169  if (&type.getDialect() == &dialect)
170  return type;
171 
172  // Check other allowed types
173  if (auto t = llvm::dyn_cast<FloatType>(type)) {
174  if (type.isBF16()) {
175  parser.emitError(typeLoc, "cannot use 'bf16' to compose SPIR-V types");
176  return Type();
177  }
178  } else if (auto t = llvm::dyn_cast<IntegerType>(type)) {
179  if (!ScalarType::isValid(t)) {
180  parser.emitError(typeLoc,
181  "only 1/8/16/32/64-bit integer type allowed but found ")
182  << type;
183  return Type();
184  }
185  } else if (auto t = llvm::dyn_cast<VectorType>(type)) {
186  if (t.getRank() != 1) {
187  parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
188  return Type();
189  }
190  if (t.getNumElements() > 4) {
191  parser.emitError(
192  typeLoc, "vector length has to be less than or equal to 4 but found ")
193  << t.getNumElements();
194  return Type();
195  }
196  } else {
197  parser.emitError(typeLoc, "cannot use ")
198  << type << " to compose SPIR-V types";
199  return Type();
200  }
201 
202  return type;
203 }
204 
205 static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect,
206  DialectAsmParser &parser) {
207  Type type;
208  SMLoc typeLoc = parser.getCurrentLocation();
209  if (parser.parseType(type))
210  return Type();
211 
212  if (auto t = llvm::dyn_cast<VectorType>(type)) {
213  if (t.getRank() != 1) {
214  parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
215  return Type();
216  }
217  if (t.getNumElements() > 4 || t.getNumElements() < 2) {
218  parser.emitError(typeLoc,
219  "matrix columns size has to be less than or equal "
220  "to 4 and greater than or equal 2, but found ")
221  << t.getNumElements();
222  return Type();
223  }
224 
225  if (!llvm::isa<FloatType>(t.getElementType())) {
226  parser.emitError(typeLoc, "matrix columns' elements must be of "
227  "Float type, got ")
228  << t.getElementType();
229  return Type();
230  }
231  } else {
232  parser.emitError(typeLoc, "matrix must be composed using vector "
233  "type, got ")
234  << type;
235  return Type();
236  }
237 
238  return type;
239 }
240 
241 static Type parseAndVerifySampledImageType(SPIRVDialect const &dialect,
242  DialectAsmParser &parser) {
243  Type type;
244  SMLoc typeLoc = parser.getCurrentLocation();
245  if (parser.parseType(type))
246  return Type();
247 
248  if (!llvm::isa<ImageType>(type)) {
249  parser.emitError(typeLoc,
250  "sampled image must be composed using image type, got ")
251  << type;
252  return Type();
253  }
254 
255  return type;
256 }
257 
258 /// Parses an optional `, stride = N` assembly segment. If no parsing failure
259 /// occurs, writes `N` to `stride` if existing and writes 0 to `stride` if
260 /// missing.
261 static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect,
262  DialectAsmParser &parser,
263  unsigned &stride) {
264  if (failed(parser.parseOptionalComma())) {
265  stride = 0;
266  return success();
267  }
268 
269  if (parser.parseKeyword("stride") || parser.parseEqual())
270  return failure();
271 
272  SMLoc strideLoc = parser.getCurrentLocation();
273  std::optional<unsigned> optStride = parseAndVerify<unsigned>(dialect, parser);
274  if (!optStride)
275  return failure();
276 
277  if (!(stride = *optStride)) {
278  parser.emitError(strideLoc, "ArrayStride must be greater than zero");
279  return failure();
280  }
281  return success();
282 }
283 
284 // element-type ::= integer-type
285 // | floating-point-type
286 // | vector-type
287 // | spirv-type
288 //
289 // array-type ::= `!spirv.array` `<` integer-literal `x` element-type
290 // (`,` `stride` `=` integer-literal)? `>`
291 static Type parseArrayType(SPIRVDialect const &dialect,
292  DialectAsmParser &parser) {
293  if (parser.parseLess())
294  return Type();
295 
296  SmallVector<int64_t, 1> countDims;
297  SMLoc countLoc = parser.getCurrentLocation();
298  if (parser.parseDimensionList(countDims, /*allowDynamic=*/false))
299  return Type();
300  if (countDims.size() != 1) {
301  parser.emitError(countLoc,
302  "expected single integer for array element count");
303  return Type();
304  }
305 
306  // According to the SPIR-V spec:
307  // "Length is the number of elements in the array. It must be at least 1."
308  int64_t count = countDims[0];
309  if (count == 0) {
310  parser.emitError(countLoc, "expected array length greater than 0");
311  return Type();
312  }
313 
314  Type elementType = parseAndVerifyType(dialect, parser);
315  if (!elementType)
316  return Type();
317 
318  unsigned stride = 0;
319  if (failed(parseOptionalArrayStride(dialect, parser, stride)))
320  return Type();
321 
322  if (parser.parseGreater())
323  return Type();
324  return ArrayType::get(elementType, count, stride);
325 }
326 
327 // cooperative-matrix-type ::=
328 // `!spirv.coopmatrix` `<` rows `x` columns `x` element-type `,`
329 // scope `,` use `>`
330 static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
331  DialectAsmParser &parser) {
332  if (parser.parseLess())
333  return {};
334 
336  SMLoc countLoc = parser.getCurrentLocation();
337  if (parser.parseDimensionList(dims, /*allowDynamic=*/false))
338  return {};
339 
340  if (dims.size() != 2) {
341  parser.emitError(countLoc, "expected row and column count");
342  return {};
343  }
344 
345  auto elementTy = parseAndVerifyType(dialect, parser);
346  if (!elementTy)
347  return {};
348 
349  Scope scope;
350  if (parser.parseComma() ||
351  spirv::parseEnumKeywordAttr(scope, parser, "scope <id>"))
352  return {};
353 
354  CooperativeMatrixUseKHR use;
355  if (parser.parseComma() ||
356  spirv::parseEnumKeywordAttr(use, parser, "use <id>"))
357  return {};
358 
359  if (parser.parseGreater())
360  return {};
361 
362  return CooperativeMatrixType::get(elementTy, dims[0], dims[1], scope, use);
363 }
364 
365 // joint-matrix-type ::= `!spirv.jointmatrix` `<`rows `x` columns `x`
366 // element-type
367 // `,` layout `,` scope`>`
368 static Type parseJointMatrixType(SPIRVDialect const &dialect,
369  DialectAsmParser &parser) {
370  if (parser.parseLess())
371  return Type();
372 
374  SMLoc countLoc = parser.getCurrentLocation();
375  if (parser.parseDimensionList(dims, /*allowDynamic=*/false))
376  return Type();
377 
378  if (dims.size() != 2) {
379  parser.emitError(countLoc, "expected rows and columns size");
380  return Type();
381  }
382 
383  auto elementTy = parseAndVerifyType(dialect, parser);
384  if (!elementTy)
385  return Type();
386  MatrixLayout matrixLayout;
387  if (parser.parseComma() ||
388  spirv::parseEnumKeywordAttr(matrixLayout, parser, "matrixLayout <id>"))
389  return Type();
390  Scope scope;
391  if (parser.parseComma() ||
392  spirv::parseEnumKeywordAttr(scope, parser, "scope <id>"))
393  return Type();
394  if (parser.parseGreater())
395  return Type();
396  return JointMatrixINTELType::get(elementTy, scope, dims[0], dims[1],
397  matrixLayout);
398 }
399 
400 // TODO: Reorder methods to be utilities first and parse*Type
401 // methods in alphabetical order
402 //
403 // storage-class ::= `UniformConstant`
404 // | `Uniform`
405 // | `Workgroup`
406 // | <and other storage classes...>
407 //
408 // pointer-type ::= `!spirv.ptr<` element-type `,` storage-class `>`
409 static Type parsePointerType(SPIRVDialect const &dialect,
410  DialectAsmParser &parser) {
411  if (parser.parseLess())
412  return Type();
413 
414  auto pointeeType = parseAndVerifyType(dialect, parser);
415  if (!pointeeType)
416  return Type();
417 
418  StringRef storageClassSpec;
419  SMLoc storageClassLoc = parser.getCurrentLocation();
420  if (parser.parseComma() || parser.parseKeyword(&storageClassSpec))
421  return Type();
422 
423  auto storageClass = symbolizeStorageClass(storageClassSpec);
424  if (!storageClass) {
425  parser.emitError(storageClassLoc, "unknown storage class: ")
426  << storageClassSpec;
427  return Type();
428  }
429  if (parser.parseGreater())
430  return Type();
431  return PointerType::get(pointeeType, *storageClass);
432 }
433 
434 // runtime-array-type ::= `!spirv.rtarray` `<` element-type
435 // (`,` `stride` `=` integer-literal)? `>`
436 static Type parseRuntimeArrayType(SPIRVDialect const &dialect,
437  DialectAsmParser &parser) {
438  if (parser.parseLess())
439  return Type();
440 
441  Type elementType = parseAndVerifyType(dialect, parser);
442  if (!elementType)
443  return Type();
444 
445  unsigned stride = 0;
446  if (failed(parseOptionalArrayStride(dialect, parser, stride)))
447  return Type();
448 
449  if (parser.parseGreater())
450  return Type();
451  return RuntimeArrayType::get(elementType, stride);
452 }
453 
454 // matrix-type ::= `!spirv.matrix` `<` integer-literal `x` element-type `>`
455 static Type parseMatrixType(SPIRVDialect const &dialect,
456  DialectAsmParser &parser) {
457  if (parser.parseLess())
458  return Type();
459 
460  SmallVector<int64_t, 1> countDims;
461  SMLoc countLoc = parser.getCurrentLocation();
462  if (parser.parseDimensionList(countDims, /*allowDynamic=*/false))
463  return Type();
464  if (countDims.size() != 1) {
465  parser.emitError(countLoc, "expected single unsigned "
466  "integer for number of columns");
467  return Type();
468  }
469 
470  int64_t columnCount = countDims[0];
471  // According to the specification, Matrices can have 2, 3, or 4 columns
472  if (columnCount < 2 || columnCount > 4) {
473  parser.emitError(countLoc, "matrix is expected to have 2, 3, or 4 "
474  "columns");
475  return Type();
476  }
477 
478  Type columnType = parseAndVerifyMatrixType(dialect, parser);
479  if (!columnType)
480  return Type();
481 
482  if (parser.parseGreater())
483  return Type();
484 
485  return MatrixType::get(columnType, columnCount);
486 }
487 
488 // Specialize this function to parse each of the parameters that define an
489 // ImageType. By default it assumes this is an enum type.
490 template <typename ValTy>
491 static std::optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
492  DialectAsmParser &parser) {
493  StringRef enumSpec;
494  SMLoc enumLoc = parser.getCurrentLocation();
495  if (parser.parseKeyword(&enumSpec)) {
496  return std::nullopt;
497  }
498 
499  auto val = spirv::symbolizeEnum<ValTy>(enumSpec);
500  if (!val)
501  parser.emitError(enumLoc, "unknown attribute: '") << enumSpec << "'";
502  return val;
503 }
504 
505 template <>
506 std::optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
507  DialectAsmParser &parser) {
508  // TODO: Further verify that the element type can be sampled
509  auto ty = parseAndVerifyType(dialect, parser);
510  if (!ty)
511  return std::nullopt;
512  return ty;
513 }
514 
515 template <typename IntTy>
516 static std::optional<IntTy> parseAndVerifyInteger(SPIRVDialect const &dialect,
517  DialectAsmParser &parser) {
518  IntTy offsetVal = std::numeric_limits<IntTy>::max();
519  if (parser.parseInteger(offsetVal))
520  return std::nullopt;
521  return offsetVal;
522 }
523 
524 template <>
525 std::optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
526  DialectAsmParser &parser) {
527  return parseAndVerifyInteger<unsigned>(dialect, parser);
528 }
529 
530 namespace {
531 // Functor object to parse a comma separated list of specs. The function
532 // parseAndVerify does the actual parsing and verification of individual
533 // elements. This is a functor since parsing the last element of the list
534 // (termination condition) needs partial specialization.
535 template <typename ParseType, typename... Args>
536 struct ParseCommaSeparatedList {
537  std::optional<std::tuple<ParseType, Args...>>
538  operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const {
539  auto parseVal = parseAndVerify<ParseType>(dialect, parser);
540  if (!parseVal)
541  return std::nullopt;
542 
543  auto numArgs = std::tuple_size<std::tuple<Args...>>::value;
544  if (numArgs != 0 && failed(parser.parseComma()))
545  return std::nullopt;
546  auto remainingValues = ParseCommaSeparatedList<Args...>{}(dialect, parser);
547  if (!remainingValues)
548  return std::nullopt;
549  return std::tuple_cat(std::tuple<ParseType>(parseVal.value()),
550  remainingValues.value());
551  }
552 };
553 
554 // Partial specialization of the function to parse a comma separated list of
555 // specs to parse the last element of the list.
556 template <typename ParseType>
557 struct ParseCommaSeparatedList<ParseType> {
558  std::optional<std::tuple<ParseType>>
559  operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const {
560  if (auto value = parseAndVerify<ParseType>(dialect, parser))
561  return std::tuple<ParseType>(*value);
562  return std::nullopt;
563  }
564 };
565 } // namespace
566 
567 // dim ::= `1D` | `2D` | `3D` | `Cube` | <and other SPIR-V Dim specifiers...>
568 //
569 // depth-info ::= `NoDepth` | `IsDepth` | `DepthUnknown`
570 //
571 // arrayed-info ::= `NonArrayed` | `Arrayed`
572 //
573 // sampling-info ::= `SingleSampled` | `MultiSampled`
574 //
575 // sampler-use-info ::= `SamplerUnknown` | `NeedSampler` | `NoSampler`
576 //
577 // format ::= `Unknown` | `Rgba32f` | <and other SPIR-V Image formats...>
578 //
579 // image-type ::= `!spirv.image<` element-type `,` dim `,` depth-info `,`
580 // arrayed-info `,` sampling-info `,`
581 // sampler-use-info `,` format `>`
582 static Type parseImageType(SPIRVDialect const &dialect,
583  DialectAsmParser &parser) {
584  if (parser.parseLess())
585  return Type();
586 
587  auto value =
588  ParseCommaSeparatedList<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
589  ImageSamplingInfo, ImageSamplerUseInfo,
590  ImageFormat>{}(dialect, parser);
591  if (!value)
592  return Type();
593 
594  if (parser.parseGreater())
595  return Type();
596  return ImageType::get(*value);
597 }
598 
599 // sampledImage-type :: = `!spirv.sampledImage<` image-type `>`
600 static Type parseSampledImageType(SPIRVDialect const &dialect,
601  DialectAsmParser &parser) {
602  if (parser.parseLess())
603  return Type();
604 
605  Type parsedType = parseAndVerifySampledImageType(dialect, parser);
606  if (!parsedType)
607  return Type();
608 
609  if (parser.parseGreater())
610  return Type();
611  return SampledImageType::get(parsedType);
612 }
613 
614 // Parse decorations associated with a member.
616  SPIRVDialect const &dialect, DialectAsmParser &parser,
617  ArrayRef<Type> memberTypes,
620 
621  // Check if the first element is offset.
622  SMLoc offsetLoc = parser.getCurrentLocation();
623  StructType::OffsetInfo offset = 0;
624  OptionalParseResult offsetParseResult = parser.parseOptionalInteger(offset);
625  if (offsetParseResult.has_value()) {
626  if (failed(*offsetParseResult))
627  return failure();
628 
629  if (offsetInfo.size() != memberTypes.size() - 1) {
630  return parser.emitError(offsetLoc,
631  "offset specification must be given for "
632  "all members");
633  }
634  offsetInfo.push_back(offset);
635  }
636 
637  // Check for no spirv::Decorations.
638  if (succeeded(parser.parseOptionalRSquare()))
639  return success();
640 
641  // If there was an offset, make sure to parse the comma.
642  if (offsetParseResult.has_value() && parser.parseComma())
643  return failure();
644 
645  // Check for spirv::Decorations.
646  auto parseDecorations = [&]() {
647  auto memberDecoration = parseAndVerify<spirv::Decoration>(dialect, parser);
648  if (!memberDecoration)
649  return failure();
650 
651  // Parse member decoration value if it exists.
652  if (succeeded(parser.parseOptionalEqual())) {
653  auto memberDecorationValue =
654  parseAndVerifyInteger<uint32_t>(dialect, parser);
655 
656  if (!memberDecorationValue)
657  return failure();
658 
659  memberDecorationInfo.emplace_back(
660  static_cast<uint32_t>(memberTypes.size() - 1), 1,
661  memberDecoration.value(), memberDecorationValue.value());
662  } else {
663  memberDecorationInfo.emplace_back(
664  static_cast<uint32_t>(memberTypes.size() - 1), 0,
665  memberDecoration.value(), 0);
666  }
667  return success();
668  };
669  if (failed(parser.parseCommaSeparatedList(parseDecorations)) ||
670  failed(parser.parseRSquare()))
671  return failure();
672 
673  return success();
674 }
675 
676 // struct-member-decoration ::= integer-literal? spirv-decoration*
677 // struct-type ::=
678 // `!spirv.struct<` (id `,`)?
679 // `(`
680 // (spirv-type (`[` struct-member-decoration `]`)?)*
681 // `)>`
682 static Type parseStructType(SPIRVDialect const &dialect,
683  DialectAsmParser &parser) {
684  // TODO: This function is quite lengthy. Break it down into smaller chunks.
685 
686  if (parser.parseLess())
687  return Type();
688 
689  StringRef identifier;
691 
692  // Check if this is an identified struct type.
693  if (succeeded(parser.parseOptionalKeyword(&identifier))) {
694  // Check if this is a possible recursive reference.
695  auto structType =
696  StructType::getIdentified(dialect.getContext(), identifier);
697  cyclicParse = parser.tryStartCyclicParse(structType);
698  if (succeeded(parser.parseOptionalGreater())) {
699  if (succeeded(cyclicParse)) {
700  parser.emitError(
701  parser.getNameLoc(),
702  "recursive struct reference not nested in struct definition");
703 
704  return Type();
705  }
706 
707  return structType;
708  }
709 
710  if (failed(parser.parseComma()))
711  return Type();
712 
713  if (failed(cyclicParse)) {
714  parser.emitError(parser.getNameLoc(),
715  "identifier already used for an enclosing struct");
716  return Type();
717  }
718  }
719 
720  if (failed(parser.parseLParen()))
721  return Type();
722 
723  if (succeeded(parser.parseOptionalRParen()) &&
724  succeeded(parser.parseOptionalGreater())) {
725  return StructType::getEmpty(dialect.getContext(), identifier);
726  }
727 
728  StructType idStructTy;
729 
730  if (!identifier.empty())
731  idStructTy = StructType::getIdentified(dialect.getContext(), identifier);
732 
733  SmallVector<Type, 4> memberTypes;
736 
737  do {
738  Type memberType;
739  if (parser.parseType(memberType))
740  return Type();
741  memberTypes.push_back(memberType);
742 
743  if (succeeded(parser.parseOptionalLSquare()))
744  if (parseStructMemberDecorations(dialect, parser, memberTypes, offsetInfo,
745  memberDecorationInfo))
746  return Type();
747  } while (succeeded(parser.parseOptionalComma()));
748 
749  if (!offsetInfo.empty() && memberTypes.size() != offsetInfo.size()) {
750  parser.emitError(parser.getNameLoc(),
751  "offset specification must be given for all members");
752  return Type();
753  }
754 
755  if (failed(parser.parseRParen()) || failed(parser.parseGreater()))
756  return Type();
757 
758  if (!identifier.empty()) {
759  if (failed(idStructTy.trySetBody(memberTypes, offsetInfo,
760  memberDecorationInfo)))
761  return Type();
762  return idStructTy;
763  }
764 
765  return StructType::get(memberTypes, offsetInfo, memberDecorationInfo);
766 }
767 
768 // spirv-type ::= array-type
769 // | element-type
770 // | image-type
771 // | pointer-type
772 // | runtime-array-type
773 // | sampled-image-type
774 // | struct-type
776  StringRef keyword;
777  if (parser.parseKeyword(&keyword))
778  return Type();
779 
780  if (keyword == "array")
781  return parseArrayType(*this, parser);
782  if (keyword == "coopmatrix")
783  return parseCooperativeMatrixType(*this, parser);
784  if (keyword == "jointmatrix")
785  return parseJointMatrixType(*this, parser);
786  if (keyword == "image")
787  return parseImageType(*this, parser);
788  if (keyword == "ptr")
789  return parsePointerType(*this, parser);
790  if (keyword == "rtarray")
791  return parseRuntimeArrayType(*this, parser);
792  if (keyword == "sampled_image")
793  return parseSampledImageType(*this, parser);
794  if (keyword == "struct")
795  return parseStructType(*this, parser);
796  if (keyword == "matrix")
797  return parseMatrixType(*this, parser);
798  parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword;
799  return Type();
800 }
801 
802 //===----------------------------------------------------------------------===//
803 // Type Printing
804 //===----------------------------------------------------------------------===//
805 
806 static void print(ArrayType type, DialectAsmPrinter &os) {
807  os << "array<" << type.getNumElements() << " x " << type.getElementType();
808  if (unsigned stride = type.getArrayStride())
809  os << ", stride=" << stride;
810  os << ">";
811 }
812 
813 static void print(RuntimeArrayType type, DialectAsmPrinter &os) {
814  os << "rtarray<" << type.getElementType();
815  if (unsigned stride = type.getArrayStride())
816  os << ", stride=" << stride;
817  os << ">";
818 }
819 
820 static void print(PointerType type, DialectAsmPrinter &os) {
821  os << "ptr<" << type.getPointeeType() << ", "
822  << stringifyStorageClass(type.getStorageClass()) << ">";
823 }
824 
825 static void print(ImageType type, DialectAsmPrinter &os) {
826  os << "image<" << type.getElementType() << ", " << stringifyDim(type.getDim())
827  << ", " << stringifyImageDepthInfo(type.getDepthInfo()) << ", "
828  << stringifyImageArrayedInfo(type.getArrayedInfo()) << ", "
829  << stringifyImageSamplingInfo(type.getSamplingInfo()) << ", "
830  << stringifyImageSamplerUseInfo(type.getSamplerUseInfo()) << ", "
831  << stringifyImageFormat(type.getImageFormat()) << ">";
832 }
833 
834 static void print(SampledImageType type, DialectAsmPrinter &os) {
835  os << "sampled_image<" << type.getImageType() << ">";
836 }
837 
838 static void print(StructType type, DialectAsmPrinter &os) {
840 
841  os << "struct<";
842 
843  if (type.isIdentified()) {
844  os << type.getIdentifier();
845 
846  cyclicPrint = os.tryStartCyclicPrint(type);
847  if (failed(cyclicPrint)) {
848  os << ">";
849  return;
850  }
851 
852  os << ", ";
853  }
854 
855  os << "(";
856 
857  auto printMember = [&](unsigned i) {
858  os << type.getElementType(i);
860  type.getMemberDecorations(i, decorations);
861  if (type.hasOffset() || !decorations.empty()) {
862  os << " [";
863  if (type.hasOffset()) {
864  os << type.getMemberOffset(i);
865  if (!decorations.empty())
866  os << ", ";
867  }
868  auto eachFn = [&os](spirv::StructType::MemberDecorationInfo decoration) {
869  os << stringifyDecoration(decoration.decoration);
870  if (decoration.hasValue) {
871  os << "=" << decoration.decorationValue;
872  }
873  };
874  llvm::interleaveComma(decorations, os, eachFn);
875  os << "]";
876  }
877  };
878  llvm::interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os,
879  printMember);
880  os << ")>";
881 }
882 
884  os << "coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x"
885  << type.getElementType() << ", " << type.getScope() << ", "
886  << type.getUse() << ">";
887 }
888 
890  os << "jointmatrix<" << type.getRows() << "x" << type.getColumns() << "x";
891  os << type.getElementType() << ", "
892  << stringifyMatrixLayout(type.getMatrixLayout());
893  os << ", " << stringifyScope(type.getScope()) << ">";
894 }
895 
896 static void print(MatrixType type, DialectAsmPrinter &os) {
897  os << "matrix<" << type.getNumColumns() << " x " << type.getColumnType();
898  os << ">";
899 }
900 
901 void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
902  TypeSwitch<Type>(type)
905  MatrixType>([&](auto type) { print(type, os); })
906  .Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
907 }
908 
909 //===----------------------------------------------------------------------===//
910 // Constant
911 //===----------------------------------------------------------------------===//
912 
914  Attribute value, Type type,
915  Location loc) {
916  if (auto poison = dyn_cast<ub::PoisonAttr>(value))
917  return builder.create<ub::PoisonOp>(loc, type, poison);
918 
919  if (!spirv::ConstantOp::isBuildableWith(type))
920  return nullptr;
921 
922  return builder.create<spirv::ConstantOp>(loc, type, value);
923 }
924 
925 //===----------------------------------------------------------------------===//
926 // Shader Interface ABI
927 //===----------------------------------------------------------------------===//
928 
929 LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op,
930  NamedAttribute attribute) {
931  StringRef symbol = attribute.getName().strref();
932  Attribute attr = attribute.getValue();
933 
934  if (symbol == spirv::getEntryPointABIAttrName()) {
935  if (!llvm::isa<spirv::EntryPointABIAttr>(attr)) {
936  return op->emitError("'")
937  << symbol << "' attribute must be an entry point ABI attribute";
938  }
939  } else if (symbol == spirv::getTargetEnvAttrName()) {
940  if (!llvm::isa<spirv::TargetEnvAttr>(attr))
941  return op->emitError("'") << symbol << "' must be a spirv::TargetEnvAttr";
942  } else {
943  return op->emitError("found unsupported '")
944  << symbol << "' attribute on operation";
945  }
946 
947  return success();
948 }
949 
950 /// Verifies the given SPIR-V `attribute` attached to a value of the given
951 /// `valueType` is valid.
953  NamedAttribute attribute) {
954  StringRef symbol = attribute.getName().strref();
955  Attribute attr = attribute.getValue();
956 
957  if (symbol == spirv::getInterfaceVarABIAttrName()) {
958  auto varABIAttr = llvm::dyn_cast<spirv::InterfaceVarABIAttr>(attr);
959  if (!varABIAttr)
960  return emitError(loc, "'")
961  << symbol << "' must be a spirv::InterfaceVarABIAttr";
962 
963  if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat())
964  return emitError(loc, "'") << symbol
965  << "' attribute cannot specify storage class "
966  "when attaching to a non-scalar value";
967  return success();
968  }
969  if (symbol == spirv::DecorationAttr::name) {
970  if (!isa<spirv::DecorationAttr>(attr))
971  return emitError(loc, "'")
972  << symbol << "' must be a spirv::DecorationAttr";
973  return success();
974  }
975 
976  return emitError(loc, "found unsupported '")
977  << symbol << "' attribute on region argument";
978 }
979 
980 LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
981  unsigned regionIndex,
982  unsigned argIndex,
983  NamedAttribute attribute) {
984  auto funcOp = dyn_cast<FunctionOpInterface>(op);
985  if (!funcOp)
986  return success();
987  Type argType = funcOp.getArgumentTypes()[argIndex];
988 
989  return verifyRegionAttribute(op->getLoc(), argType, attribute);
990 }
991 
992 LogicalResult SPIRVDialect::verifyRegionResultAttribute(
993  Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/,
994  NamedAttribute attribute) {
995  return op->emitError("cannot attach SPIR-V attributes to region result");
996 }
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
std::optional< unsigned > parseAndVerify< unsigned >(SPIRVDialect const &dialect, DialectAsmParser &parser)
static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect, DialectAsmParser &parser, unsigned &stride)
Parses an optional `, stride = N` assembly segment.
static LogicalResult verifyRegionAttribute(Location loc, Type valueType, NamedAttribute attribute)
Verifies the given SPIR-V attribute attached to a value of the given valueType is valid.
static Type parseJointMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static void print(ArrayType type, DialectAsmPrinter &os)
static Type parseSampledImageType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseAndVerifyType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static ParseResult parseStructMemberDecorations(SPIRVDialect const &dialect, DialectAsmParser &parser, ArrayRef< Type > memberTypes, SmallVectorImpl< StructType::OffsetInfo > &offsetInfo, SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorationInfo)
static Type parseAndVerifySampledImageType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseCooperativeMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
std::optional< Type > parseAndVerify< Type >(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseArrayType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static bool containsReturn(Region &region)
Returns true if the given region contains spirv.Return or spirv.ReturnValue ops.
static Type parseStructType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseRuntimeArrayType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static std::optional< IntTy > parseAndVerifyInteger(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseImageType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static std::optional< ValTy > parseAndVerify(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parsePointerType(SPIRVDialect const &dialect, DialectAsmParser &parser)
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
FailureOr< CyclicParseReset > tryStartCyclicParse(AttrOrTypeT attrOrType)
Attempts to start a cyclic parsing region for attrOrType.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseOptionalGreater()=0
Parse a '>' token if present.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalRSquare()=0
Parse a ] token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
FailureOr< CyclicPrintReset > tryStartCyclicPrint(AttrOrTypeT attrOrType)
Attempts to start a cyclic printing region for attrOrType.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:30
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:202
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:49
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:216
This class helps build Operations.
Definition: Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
This class represents success/failure for parsing-like operations that find it important to chain tog...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:118
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:123
bool isBF16() const
Definition: Types.cpp:48
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
Type getElementType() const
Definition: SPIRVTypes.cpp:66
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:68
unsigned getNumElements() const
Definition: SPIRVTypes.cpp:64
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:52
Scope getScope() const
Returns the scope of the matrix.
Definition: SPIRVTypes.cpp:246
uint32_t getRows() const
Returns the number of rows of the matrix.
Definition: SPIRVTypes.cpp:240
uint32_t getColumns() const
Returns the number of columns of the matrix.
Definition: SPIRVTypes.cpp:242
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
Definition: SPIRVTypes.cpp:228
CooperativeMatrixUseKHR getUse() const
Returns the use parameter of the cooperative matrix.
Definition: SPIRVTypes.cpp:248
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition: SPIRVTypes.h:169
ImageDepthInfo getDepthInfo() const
Definition: SPIRVTypes.cpp:424
ImageArrayedInfo getArrayedInfo() const
Definition: SPIRVTypes.cpp:426
ImageFormat getImageFormat() const
Definition: SPIRVTypes.cpp:438
ImageSamplerUseInfo getSamplerUseInfo() const
Definition: SPIRVTypes.cpp:434
Type getElementType() const
Definition: SPIRVTypes.cpp:420
ImageSamplingInfo getSamplingInfo() const
Definition: SPIRVTypes.cpp:430
Scope getScope() const
Returns the scope of the joint matrix.
Definition: SPIRVTypes.cpp:309
unsigned getColumns() const
Returns the number of columns of the matrix.
Definition: SPIRVTypes.cpp:313
static JointMatrixINTELType get(Type elementType, Scope scope, unsigned rows, unsigned columns, MatrixLayout matrixLayout)
Definition: SPIRVTypes.cpp:298
unsigned getRows() const
Returns the number of rows of the matrix.
Definition: SPIRVTypes.cpp:311
MatrixLayout getMatrixLayout() const
Returns the layout of the matrix.
Definition: SPIRVTypes.cpp:315
static MatrixType get(Type columnType, uint32_t columnCount)
Type getColumnType() const
unsigned getNumColumns() const
Returns the number of columns.
Type getPointeeType() const
Definition: SPIRVTypes.cpp:485
StorageClass getStorageClass() const
Definition: SPIRVTypes.cpp:487
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:481
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:548
static RuntimeArrayType get(Type elementType)
Definition: SPIRVTypes.cpp:538
static SampledImageType get(Type imageType)
Definition: SPIRVTypes.cpp:807
static bool isValid(FloatType)
Returns true if the given float type is valid for the SPIR-V dialect.
Definition: SPIRVTypes.cpp:582
SPIR-V struct type.
Definition: SPIRVTypes.h:293
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
unsigned getNumElements() const
Type getElementType(unsigned) const
LogicalResult trySetBody(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Sets the contents of an incomplete identified StructType.
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
uint64_t getMemberOffset(unsigned) const
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
ParseResult parseEnumKeywordAttr(EnumClass &value, ParserType &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next keyword in parser as an enumerant of the given EnumClass.
StringRef getTargetEnvAttrName()
Returns the attribute name for specifying SPIR-V target environment.
StringRef getEntryPointABIAttrName()
Returns the attribute name for specifying entry point information.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type to an MLIR context if it was valid.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26