MLIR  20.0.0git
SPIRVDialect.cpp
Go to the documentation of this file.
//===- SPIRVDialect.cpp - MLIR SPIR-V dialect -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the SPIR-V dialect in MLIR.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
15 #include "SPIRVParsingUtils.h"
16 
22 #include "mlir/IR/Builders.h"
23 #include "mlir/IR/BuiltinTypes.h"
25 #include "mlir/IR/MLIRContext.h"
26 #include "mlir/Parser/Parser.h"
28 #include "llvm/ADT/DenseMap.h"
29 #include "llvm/ADT/Sequence.h"
30 #include "llvm/ADT/SetVector.h"
31 #include "llvm/ADT/StringExtras.h"
32 #include "llvm/ADT/StringMap.h"
33 #include "llvm/ADT/TypeSwitch.h"
34 #include "llvm/Support/raw_ostream.h"
35 
36 using namespace mlir;
37 using namespace mlir::spirv;
38 
39 #include "mlir/Dialect/SPIRV/IR/SPIRVOpsDialect.cpp.inc"
40 
41 //===----------------------------------------------------------------------===//
42 // InlinerInterface
43 //===----------------------------------------------------------------------===//
44 
45 /// Returns true if the given region contains spirv.Return or spirv.ReturnValue
46 /// ops.
47 static inline bool containsReturn(Region &region) {
48  return llvm::any_of(region, [](Block &block) {
49  Operation *terminator = block.getTerminator();
50  return isa<spirv::ReturnOp, spirv::ReturnValueOp>(terminator);
51  });
52 }
53 
54 namespace {
55 /// This class defines the interface for inlining within the SPIR-V dialect.
56 struct SPIRVInlinerInterface : public DialectInlinerInterface {
58 
59  /// All call operations within SPIRV can be inlined.
60  bool isLegalToInline(Operation *call, Operation *callable,
61  bool wouldBeCloned) const final {
62  return true;
63  }
64 
65  /// Returns true if the given region 'src' can be inlined into the region
66  /// 'dest' that is attached to an operation registered to the current dialect.
67  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
68  IRMapping &) const final {
69  // Return true here when inlining into spirv.func, spirv.mlir.selection, and
70  // spirv.mlir.loop operations.
71  auto *op = dest->getParentOp();
72  return isa<spirv::FuncOp, spirv::SelectionOp, spirv::LoopOp>(op);
73  }
74 
75  /// Returns true if the given operation 'op', that is registered to this
76  /// dialect, can be inlined into the region 'dest' that is attached to an
77  /// operation registered to the current dialect.
78  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
79  IRMapping &) const final {
80  // TODO: Enable inlining structured control flows with return.
81  if ((isa<spirv::SelectionOp, spirv::LoopOp>(op)) &&
82  containsReturn(op->getRegion(0)))
83  return false;
84  // TODO: we need to filter OpKill here to avoid inlining it to
85  // a loop continue construct:
86  // https://github.com/KhronosGroup/SPIRV-Headers/issues/86
87  // However OpKill is fragment shader specific and we don't support it yet.
88  return true;
89  }
90 
91  /// Handle the given inlined terminator by replacing it with a new operation
92  /// as necessary.
93  void handleTerminator(Operation *op, Block *newDest) const final {
94  if (auto returnOp = dyn_cast<spirv::ReturnOp>(op)) {
95  OpBuilder(op).create<spirv::BranchOp>(op->getLoc(), newDest);
96  op->erase();
97  } else if (auto retValOp = dyn_cast<spirv::ReturnValueOp>(op)) {
98  OpBuilder(op).create<spirv::BranchOp>(retValOp->getLoc(), newDest,
99  retValOp->getOperands());
100  op->erase();
101  }
102  }
103 
104  /// Handle the given inlined terminator by replacing it with a new operation
105  /// as necessary.
106  void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
107  // Only spirv.ReturnValue needs to be handled here.
108  auto retValOp = dyn_cast<spirv::ReturnValueOp>(op);
109  if (!retValOp)
110  return;
111 
112  // Replace the values directly with the return operands.
113  assert(valuesToRepl.size() == 1 &&
114  "spirv.ReturnValue expected to only handle one result");
115  valuesToRepl.front().replaceAllUsesWith(retValOp.getValue());
116  }
117 };
118 } // namespace
119 
120 //===----------------------------------------------------------------------===//
121 // SPIR-V Dialect
122 //===----------------------------------------------------------------------===//
123 
/// Sets up the SPIR-V dialect: registers its attributes, types, operations
/// and the inliner interface, allows unregistered ops, and declares the
/// promised gpu::TargetAttrInterface implementation for TargetEnvAttr.
void SPIRVDialect::initialize() {
  // registerAttributes()/registerTypes() are not defined in this part of the
  // file; presumably they live alongside the attribute/type definitions.
  registerAttributes();
  registerTypes();

  // Add SPIR-V ops. The op list is expanded from the generated
  // SPIRVOps.cpp.inc file under GET_OP_LIST.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"
      >();

  // Enables the inliner to operate on SPIR-V ops (SPIRVInlinerInterface).
  addInterfaces<SPIRVInlinerInterface>();

  // Allow unknown operations because SPIR-V is extensible.
  allowUnknownOperations();
  // Declares that TargetEnvAttr will implement gpu::TargetAttrInterface; the
  // implementation itself is expected to be attached elsewhere (e.g. by an
  // extension) — NOTE(review): confirm where it is registered.
  declarePromisedInterface<gpu::TargetAttrInterface, TargetEnvAttr>();
}
140 
141 std::string SPIRVDialect::getAttributeName(Decoration decoration) {
142  return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration));
143 }
144 
145 //===----------------------------------------------------------------------===//
146 // Type Parsing
147 //===----------------------------------------------------------------------===//
148 
// Forward declarations. `parseAndVerify<ValTy>` parses one parameter of a
// composite type; the generic version (defined further down) handles enum
// keywords. The `Type` and `unsigned` explicit specializations must be
// declared before any use, because parseOptionalArrayStride calls
// parseAndVerify<unsigned> before its definition appears.
template <typename ValTy>
static std::optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
                                           DialectAsmParser &parser);
template <>
std::optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
                                         DialectAsmParser &parser);

template <>
std::optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
                                                 DialectAsmParser &parser);
160 
161 static Type parseAndVerifyType(SPIRVDialect const &dialect,
162  DialectAsmParser &parser) {
163  Type type;
164  SMLoc typeLoc = parser.getCurrentLocation();
165  if (parser.parseType(type))
166  return Type();
167 
168  // Allow SPIR-V dialect types
169  if (&type.getDialect() == &dialect)
170  return type;
171 
172  // Check other allowed types
173  if (auto t = llvm::dyn_cast<FloatType>(type)) {
174  if (type.isBF16()) {
175  parser.emitError(typeLoc, "cannot use 'bf16' to compose SPIR-V types");
176  return Type();
177  }
178  } else if (auto t = llvm::dyn_cast<IntegerType>(type)) {
179  if (!ScalarType::isValid(t)) {
180  parser.emitError(typeLoc,
181  "only 1/8/16/32/64-bit integer type allowed but found ")
182  << type;
183  return Type();
184  }
185  } else if (auto t = llvm::dyn_cast<VectorType>(type)) {
186  if (t.getRank() != 1) {
187  parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
188  return Type();
189  }
190  if (t.getNumElements() > 4) {
191  parser.emitError(
192  typeLoc, "vector length has to be less than or equal to 4 but found ")
193  << t.getNumElements();
194  return Type();
195  }
196  } else {
197  parser.emitError(typeLoc, "cannot use ")
198  << type << " to compose SPIR-V types";
199  return Type();
200  }
201 
202  return type;
203 }
204 
205 static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect,
206  DialectAsmParser &parser) {
207  Type type;
208  SMLoc typeLoc = parser.getCurrentLocation();
209  if (parser.parseType(type))
210  return Type();
211 
212  if (auto t = llvm::dyn_cast<VectorType>(type)) {
213  if (t.getRank() != 1) {
214  parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
215  return Type();
216  }
217  if (t.getNumElements() > 4 || t.getNumElements() < 2) {
218  parser.emitError(typeLoc,
219  "matrix columns size has to be less than or equal "
220  "to 4 and greater than or equal 2, but found ")
221  << t.getNumElements();
222  return Type();
223  }
224 
225  if (!llvm::isa<FloatType>(t.getElementType())) {
226  parser.emitError(typeLoc, "matrix columns' elements must be of "
227  "Float type, got ")
228  << t.getElementType();
229  return Type();
230  }
231  } else {
232  parser.emitError(typeLoc, "matrix must be composed using vector "
233  "type, got ")
234  << type;
235  return Type();
236  }
237 
238  return type;
239 }
240 
241 static Type parseAndVerifySampledImageType(SPIRVDialect const &dialect,
242  DialectAsmParser &parser) {
243  Type type;
244  SMLoc typeLoc = parser.getCurrentLocation();
245  if (parser.parseType(type))
246  return Type();
247 
248  if (!llvm::isa<ImageType>(type)) {
249  parser.emitError(typeLoc,
250  "sampled image must be composed using image type, got ")
251  << type;
252  return Type();
253  }
254 
255  return type;
256 }
257 
258 /// Parses an optional `, stride = N` assembly segment. If no parsing failure
259 /// occurs, writes `N` to `stride` if existing and writes 0 to `stride` if
260 /// missing.
261 static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect,
262  DialectAsmParser &parser,
263  unsigned &stride) {
264  if (failed(parser.parseOptionalComma())) {
265  stride = 0;
266  return success();
267  }
268 
269  if (parser.parseKeyword("stride") || parser.parseEqual())
270  return failure();
271 
272  SMLoc strideLoc = parser.getCurrentLocation();
273  std::optional<unsigned> optStride = parseAndVerify<unsigned>(dialect, parser);
274  if (!optStride)
275  return failure();
276 
277  if (!(stride = *optStride)) {
278  parser.emitError(strideLoc, "ArrayStride must be greater than zero");
279  return failure();
280  }
281  return success();
282 }
283 
284 // element-type ::= integer-type
285 // | floating-point-type
286 // | vector-type
287 // | spirv-type
288 //
289 // array-type ::= `!spirv.array` `<` integer-literal `x` element-type
290 // (`,` `stride` `=` integer-literal)? `>`
291 static Type parseArrayType(SPIRVDialect const &dialect,
292  DialectAsmParser &parser) {
293  if (parser.parseLess())
294  return Type();
295 
296  SmallVector<int64_t, 1> countDims;
297  SMLoc countLoc = parser.getCurrentLocation();
298  if (parser.parseDimensionList(countDims, /*allowDynamic=*/false))
299  return Type();
300  if (countDims.size() != 1) {
301  parser.emitError(countLoc,
302  "expected single integer for array element count");
303  return Type();
304  }
305 
306  // According to the SPIR-V spec:
307  // "Length is the number of elements in the array. It must be at least 1."
308  int64_t count = countDims[0];
309  if (count == 0) {
310  parser.emitError(countLoc, "expected array length greater than 0");
311  return Type();
312  }
313 
314  Type elementType = parseAndVerifyType(dialect, parser);
315  if (!elementType)
316  return Type();
317 
318  unsigned stride = 0;
319  if (failed(parseOptionalArrayStride(dialect, parser, stride)))
320  return Type();
321 
322  if (parser.parseGreater())
323  return Type();
324  return ArrayType::get(elementType, count, stride);
325 }
326 
327 // cooperative-matrix-type ::=
328 // `!spirv.coopmatrix` `<` rows `x` columns `x` element-type `,`
329 // scope `,` use `>`
330 static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
331  DialectAsmParser &parser) {
332  if (parser.parseLess())
333  return {};
334 
336  SMLoc countLoc = parser.getCurrentLocation();
337  if (parser.parseDimensionList(dims, /*allowDynamic=*/false))
338  return {};
339 
340  if (dims.size() != 2) {
341  parser.emitError(countLoc, "expected row and column count");
342  return {};
343  }
344 
345  auto elementTy = parseAndVerifyType(dialect, parser);
346  if (!elementTy)
347  return {};
348 
349  Scope scope;
350  if (parser.parseComma() ||
351  spirv::parseEnumKeywordAttr(scope, parser, "scope <id>"))
352  return {};
353 
354  CooperativeMatrixUseKHR use;
355  if (parser.parseComma() ||
356  spirv::parseEnumKeywordAttr(use, parser, "use <id>"))
357  return {};
358 
359  if (parser.parseGreater())
360  return {};
361 
362  return CooperativeMatrixType::get(elementTy, dims[0], dims[1], scope, use);
363 }
364 
365 // joint-matrix-type ::= `!spirv.jointmatrix` `<`rows `x` columns `x`
366 // element-type
367 // `,` layout `,` scope`>`
368 static Type parseJointMatrixType(SPIRVDialect const &dialect,
369  DialectAsmParser &parser) {
370  if (parser.parseLess())
371  return Type();
372 
374  SMLoc countLoc = parser.getCurrentLocation();
375  if (parser.parseDimensionList(dims, /*allowDynamic=*/false))
376  return Type();
377 
378  if (dims.size() != 2) {
379  parser.emitError(countLoc, "expected rows and columns size");
380  return Type();
381  }
382 
383  auto elementTy = parseAndVerifyType(dialect, parser);
384  if (!elementTy)
385  return Type();
386  MatrixLayout matrixLayout;
387  if (parser.parseComma() ||
388  spirv::parseEnumKeywordAttr(matrixLayout, parser, "matrixLayout <id>"))
389  return Type();
390  Scope scope;
391  if (parser.parseComma() ||
392  spirv::parseEnumKeywordAttr(scope, parser, "scope <id>"))
393  return Type();
394  if (parser.parseGreater())
395  return Type();
396  return JointMatrixINTELType::get(elementTy, scope, dims[0], dims[1],
397  matrixLayout);
398 }
399 
400 // TODO: Reorder methods to be utilities first and parse*Type
401 // methods in alphabetical order
402 //
403 // storage-class ::= `UniformConstant`
404 // | `Uniform`
405 // | `Workgroup`
406 // | <and other storage classes...>
407 //
408 // pointer-type ::= `!spirv.ptr<` element-type `,` storage-class `>`
409 static Type parsePointerType(SPIRVDialect const &dialect,
410  DialectAsmParser &parser) {
411  if (parser.parseLess())
412  return Type();
413 
414  auto pointeeType = parseAndVerifyType(dialect, parser);
415  if (!pointeeType)
416  return Type();
417 
418  StringRef storageClassSpec;
419  SMLoc storageClassLoc = parser.getCurrentLocation();
420  if (parser.parseComma() || parser.parseKeyword(&storageClassSpec))
421  return Type();
422 
423  auto storageClass = symbolizeStorageClass(storageClassSpec);
424  if (!storageClass) {
425  parser.emitError(storageClassLoc, "unknown storage class: ")
426  << storageClassSpec;
427  return Type();
428  }
429  if (parser.parseGreater())
430  return Type();
431  return PointerType::get(pointeeType, *storageClass);
432 }
433 
434 // runtime-array-type ::= `!spirv.rtarray` `<` element-type
435 // (`,` `stride` `=` integer-literal)? `>`
436 static Type parseRuntimeArrayType(SPIRVDialect const &dialect,
437  DialectAsmParser &parser) {
438  if (parser.parseLess())
439  return Type();
440 
441  Type elementType = parseAndVerifyType(dialect, parser);
442  if (!elementType)
443  return Type();
444 
445  unsigned stride = 0;
446  if (failed(parseOptionalArrayStride(dialect, parser, stride)))
447  return Type();
448 
449  if (parser.parseGreater())
450  return Type();
451  return RuntimeArrayType::get(elementType, stride);
452 }
453 
454 // matrix-type ::= `!spirv.matrix` `<` integer-literal `x` element-type `>`
455 static Type parseMatrixType(SPIRVDialect const &dialect,
456  DialectAsmParser &parser) {
457  if (parser.parseLess())
458  return Type();
459 
460  SmallVector<int64_t, 1> countDims;
461  SMLoc countLoc = parser.getCurrentLocation();
462  if (parser.parseDimensionList(countDims, /*allowDynamic=*/false))
463  return Type();
464  if (countDims.size() != 1) {
465  parser.emitError(countLoc, "expected single unsigned "
466  "integer for number of columns");
467  return Type();
468  }
469 
470  int64_t columnCount = countDims[0];
471  // According to the specification, Matrices can have 2, 3, or 4 columns
472  if (columnCount < 2 || columnCount > 4) {
473  parser.emitError(countLoc, "matrix is expected to have 2, 3, or 4 "
474  "columns");
475  return Type();
476  }
477 
478  Type columnType = parseAndVerifyMatrixType(dialect, parser);
479  if (!columnType)
480  return Type();
481 
482  if (parser.parseGreater())
483  return Type();
484 
485  return MatrixType::get(columnType, columnCount);
486 }
487 
488 // Specialize this function to parse each of the parameters that define an
489 // ImageType. By default it assumes this is an enum type.
490 template <typename ValTy>
491 static std::optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
492  DialectAsmParser &parser) {
493  StringRef enumSpec;
494  SMLoc enumLoc = parser.getCurrentLocation();
495  if (parser.parseKeyword(&enumSpec)) {
496  return std::nullopt;
497  }
498 
499  auto val = spirv::symbolizeEnum<ValTy>(enumSpec);
500  if (!val)
501  parser.emitError(enumLoc, "unknown attribute: '") << enumSpec << "'";
502  return val;
503 }
504 
505 template <>
506 std::optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
507  DialectAsmParser &parser) {
508  // TODO: Further verify that the element type can be sampled
509  auto ty = parseAndVerifyType(dialect, parser);
510  if (!ty)
511  return std::nullopt;
512  return ty;
513 }
514 
515 template <typename IntTy>
516 static std::optional<IntTy> parseAndVerifyInteger(SPIRVDialect const &dialect,
517  DialectAsmParser &parser) {
518  IntTy offsetVal = std::numeric_limits<IntTy>::max();
519  if (parser.parseInteger(offsetVal))
520  return std::nullopt;
521  return offsetVal;
522 }
523 
524 template <>
525 std::optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
526  DialectAsmParser &parser) {
527  return parseAndVerifyInteger<unsigned>(dialect, parser);
528 }
529 
530 namespace {
531 // Functor object to parse a comma separated list of specs. The function
532 // parseAndVerify does the actual parsing and verification of individual
533 // elements. This is a functor since parsing the last element of the list
534 // (termination condition) needs partial specialization.
535 template <typename ParseType, typename... Args>
536 struct ParseCommaSeparatedList {
537  std::optional<std::tuple<ParseType, Args...>>
538  operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const {
539  auto parseVal = parseAndVerify<ParseType>(dialect, parser);
540  if (!parseVal)
541  return std::nullopt;
542 
543  auto numArgs = std::tuple_size<std::tuple<Args...>>::value;
544  if (numArgs != 0 && failed(parser.parseComma()))
545  return std::nullopt;
546  auto remainingValues = ParseCommaSeparatedList<Args...>{}(dialect, parser);
547  if (!remainingValues)
548  return std::nullopt;
549  return std::tuple_cat(std::tuple<ParseType>(parseVal.value()),
550  remainingValues.value());
551  }
552 };
553 
554 // Partial specialization of the function to parse a comma separated list of
555 // specs to parse the last element of the list.
556 template <typename ParseType>
557 struct ParseCommaSeparatedList<ParseType> {
558  std::optional<std::tuple<ParseType>>
559  operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const {
560  if (auto value = parseAndVerify<ParseType>(dialect, parser))
561  return std::tuple<ParseType>(*value);
562  return std::nullopt;
563  }
564 };
565 } // namespace
566 
567 // dim ::= `1D` | `2D` | `3D` | `Cube` | <and other SPIR-V Dim specifiers...>
568 //
569 // depth-info ::= `NoDepth` | `IsDepth` | `DepthUnknown`
570 //
571 // arrayed-info ::= `NonArrayed` | `Arrayed`
572 //
573 // sampling-info ::= `SingleSampled` | `MultiSampled`
574 //
575 // sampler-use-info ::= `SamplerUnknown` | `NeedSampler` | `NoSampler`
576 //
577 // format ::= `Unknown` | `Rgba32f` | <and other SPIR-V Image formats...>
578 //
579 // image-type ::= `!spirv.image<` element-type `,` dim `,` depth-info `,`
580 // arrayed-info `,` sampling-info `,`
581 // sampler-use-info `,` format `>`
582 static Type parseImageType(SPIRVDialect const &dialect,
583  DialectAsmParser &parser) {
584  if (parser.parseLess())
585  return Type();
586 
587  auto value =
588  ParseCommaSeparatedList<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
589  ImageSamplingInfo, ImageSamplerUseInfo,
590  ImageFormat>{}(dialect, parser);
591  if (!value)
592  return Type();
593 
594  if (parser.parseGreater())
595  return Type();
596  return ImageType::get(*value);
597 }
598 
599 // sampledImage-type :: = `!spirv.sampledImage<` image-type `>`
600 static Type parseSampledImageType(SPIRVDialect const &dialect,
601  DialectAsmParser &parser) {
602  if (parser.parseLess())
603  return Type();
604 
605  Type parsedType = parseAndVerifySampledImageType(dialect, parser);
606  if (!parsedType)
607  return Type();
608 
609  if (parser.parseGreater())
610  return Type();
611  return SampledImageType::get(parsedType);
612 }
613 
614 // Parse decorations associated with a member.
615 static ParseResult parseStructMemberDecorations(
616  SPIRVDialect const &dialect, DialectAsmParser &parser,
617  ArrayRef<Type> memberTypes,
620 
621  // Check if the first element is offset.
622  SMLoc offsetLoc = parser.getCurrentLocation();
623  StructType::OffsetInfo offset = 0;
624  OptionalParseResult offsetParseResult = parser.parseOptionalInteger(offset);
625  if (offsetParseResult.has_value()) {
626  if (failed(*offsetParseResult))
627  return failure();
628 
629  if (offsetInfo.size() != memberTypes.size() - 1) {
630  return parser.emitError(offsetLoc,
631  "offset specification must be given for "
632  "all members");
633  }
634  offsetInfo.push_back(offset);
635  }
636 
637  // Check for no spirv::Decorations.
638  if (succeeded(parser.parseOptionalRSquare()))
639  return success();
640 
641  // If there was an offset, make sure to parse the comma.
642  if (offsetParseResult.has_value() && parser.parseComma())
643  return failure();
644 
645  // Check for spirv::Decorations.
646  auto parseDecorations = [&]() {
647  auto memberDecoration = parseAndVerify<spirv::Decoration>(dialect, parser);
648  if (!memberDecoration)
649  return failure();
650 
651  // Parse member decoration value if it exists.
652  if (succeeded(parser.parseOptionalEqual())) {
653  auto memberDecorationValue =
654  parseAndVerifyInteger<uint32_t>(dialect, parser);
655 
656  if (!memberDecorationValue)
657  return failure();
658 
659  memberDecorationInfo.emplace_back(
660  static_cast<uint32_t>(memberTypes.size() - 1), 1,
661  memberDecoration.value(), memberDecorationValue.value());
662  } else {
663  memberDecorationInfo.emplace_back(
664  static_cast<uint32_t>(memberTypes.size() - 1), 0,
665  memberDecoration.value(), 0);
666  }
667  return success();
668  };
669  if (failed(parser.parseCommaSeparatedList(parseDecorations)) ||
670  failed(parser.parseRSquare()))
671  return failure();
672 
673  return success();
674 }
675 
676 // struct-member-decoration ::= integer-literal? spirv-decoration*
677 // struct-type ::=
678 // `!spirv.struct<` (id `,`)?
679 // `(`
680 // (spirv-type (`[` struct-member-decoration `]`)?)*
681 // `)>`
682 static Type parseStructType(SPIRVDialect const &dialect,
683  DialectAsmParser &parser) {
684  // TODO: This function is quite lengthy. Break it down into smaller chunks.
685 
686  if (parser.parseLess())
687  return Type();
688 
689  StringRef identifier;
690  FailureOr<DialectAsmParser::CyclicParseReset> cyclicParse;
691 
692  // Check if this is an identified struct type.
693  if (succeeded(parser.parseOptionalKeyword(&identifier))) {
694  // Check if this is a possible recursive reference.
695  auto structType =
696  StructType::getIdentified(dialect.getContext(), identifier);
697  cyclicParse = parser.tryStartCyclicParse(structType);
698  if (succeeded(parser.parseOptionalGreater())) {
699  if (succeeded(cyclicParse)) {
700  parser.emitError(
701  parser.getNameLoc(),
702  "recursive struct reference not nested in struct definition");
703 
704  return Type();
705  }
706 
707  return structType;
708  }
709 
710  if (failed(parser.parseComma()))
711  return Type();
712 
713  if (failed(cyclicParse)) {
714  parser.emitError(parser.getNameLoc(),
715  "identifier already used for an enclosing struct");
716  return Type();
717  }
718  }
719 
720  if (failed(parser.parseLParen()))
721  return Type();
722 
723  if (succeeded(parser.parseOptionalRParen()) &&
724  succeeded(parser.parseOptionalGreater())) {
725  return StructType::getEmpty(dialect.getContext(), identifier);
726  }
727 
728  StructType idStructTy;
729 
730  if (!identifier.empty())
731  idStructTy = StructType::getIdentified(dialect.getContext(), identifier);
732 
733  SmallVector<Type, 4> memberTypes;
736 
737  do {
738  Type memberType;
739  if (parser.parseType(memberType))
740  return Type();
741  memberTypes.push_back(memberType);
742 
743  if (succeeded(parser.parseOptionalLSquare()))
744  if (parseStructMemberDecorations(dialect, parser, memberTypes, offsetInfo,
745  memberDecorationInfo))
746  return Type();
747  } while (succeeded(parser.parseOptionalComma()));
748 
749  if (!offsetInfo.empty() && memberTypes.size() != offsetInfo.size()) {
750  parser.emitError(parser.getNameLoc(),
751  "offset specification must be given for all members");
752  return Type();
753  }
754 
755  if (failed(parser.parseRParen()) || failed(parser.parseGreater()))
756  return Type();
757 
758  if (!identifier.empty()) {
759  if (failed(idStructTy.trySetBody(memberTypes, offsetInfo,
760  memberDecorationInfo)))
761  return Type();
762  return idStructTy;
763  }
764 
765  return StructType::get(memberTypes, offsetInfo, memberDecorationInfo);
766 }
767 
768 // spirv-type ::= array-type
769 // | element-type
770 // | image-type
771 // | pointer-type
772 // | runtime-array-type
773 // | sampled-image-type
774 // | struct-type
776  StringRef keyword;
777  if (parser.parseKeyword(&keyword))
778  return Type();
779 
780  if (keyword == "array")
781  return parseArrayType(*this, parser);
782  if (keyword == "coopmatrix")
783  return parseCooperativeMatrixType(*this, parser);
784  if (keyword == "jointmatrix")
785  return parseJointMatrixType(*this, parser);
786  if (keyword == "image")
787  return parseImageType(*this, parser);
788  if (keyword == "ptr")
789  return parsePointerType(*this, parser);
790  if (keyword == "rtarray")
791  return parseRuntimeArrayType(*this, parser);
792  if (keyword == "sampled_image")
793  return parseSampledImageType(*this, parser);
794  if (keyword == "struct")
795  return parseStructType(*this, parser);
796  if (keyword == "matrix")
797  return parseMatrixType(*this, parser);
798  parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword;
799  return Type();
800 }
801 
802 //===----------------------------------------------------------------------===//
803 // Type Printing
804 //===----------------------------------------------------------------------===//
805 
806 static void print(ArrayType type, DialectAsmPrinter &os) {
807  os << "array<" << type.getNumElements() << " x " << type.getElementType();
808  if (unsigned stride = type.getArrayStride())
809  os << ", stride=" << stride;
810  os << ">";
811 }
812 
813 static void print(RuntimeArrayType type, DialectAsmPrinter &os) {
814  os << "rtarray<" << type.getElementType();
815  if (unsigned stride = type.getArrayStride())
816  os << ", stride=" << stride;
817  os << ">";
818 }
819 
820 static void print(PointerType type, DialectAsmPrinter &os) {
821  os << "ptr<" << type.getPointeeType() << ", "
822  << stringifyStorageClass(type.getStorageClass()) << ">";
823 }
824 
825 static void print(ImageType type, DialectAsmPrinter &os) {
826  os << "image<" << type.getElementType() << ", " << stringifyDim(type.getDim())
827  << ", " << stringifyImageDepthInfo(type.getDepthInfo()) << ", "
828  << stringifyImageArrayedInfo(type.getArrayedInfo()) << ", "
829  << stringifyImageSamplingInfo(type.getSamplingInfo()) << ", "
830  << stringifyImageSamplerUseInfo(type.getSamplerUseInfo()) << ", "
831  << stringifyImageFormat(type.getImageFormat()) << ">";
832 }
833 
834 static void print(SampledImageType type, DialectAsmPrinter &os) {
835  os << "sampled_image<" << type.getImageType() << ">";
836 }
837 
838 static void print(StructType type, DialectAsmPrinter &os) {
839  FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint;
840 
841  os << "struct<";
842 
843  if (type.isIdentified()) {
844  os << type.getIdentifier();
845 
846  cyclicPrint = os.tryStartCyclicPrint(type);
847  if (failed(cyclicPrint)) {
848  os << ">";
849  return;
850  }
851 
852  os << ", ";
853  }
854 
855  os << "(";
856 
857  auto printMember = [&](unsigned i) {
858  os << type.getElementType(i);
860  type.getMemberDecorations(i, decorations);
861  if (type.hasOffset() || !decorations.empty()) {
862  os << " [";
863  if (type.hasOffset()) {
864  os << type.getMemberOffset(i);
865  if (!decorations.empty())
866  os << ", ";
867  }
868  auto eachFn = [&os](spirv::StructType::MemberDecorationInfo decoration) {
869  os << stringifyDecoration(decoration.decoration);
870  if (decoration.hasValue) {
871  os << "=" << decoration.decorationValue;
872  }
873  };
874  llvm::interleaveComma(decorations, os, eachFn);
875  os << "]";
876  }
877  };
878  llvm::interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os,
879  printMember);
880  os << ")>";
881 }
882 
884  os << "coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x"
885  << type.getElementType() << ", " << type.getScope() << ", "
886  << type.getUse() << ">";
887 }
888 
890  os << "jointmatrix<" << type.getRows() << "x" << type.getColumns() << "x";
891  os << type.getElementType() << ", "
892  << stringifyMatrixLayout(type.getMatrixLayout());
893  os << ", " << stringifyScope(type.getScope()) << ">";
894 }
895 
896 static void print(MatrixType type, DialectAsmPrinter &os) {
897  os << "matrix<" << type.getNumColumns() << " x " << type.getColumnType();
898  os << ">";
899 }
900 
901 void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
902  TypeSwitch<Type>(type)
905  MatrixType>([&](auto type) { print(type, os); })
906  .Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
907 }
908 
909 //===----------------------------------------------------------------------===//
910 // Constant
911 //===----------------------------------------------------------------------===//
912 
914  Attribute value, Type type,
915  Location loc) {
916  if (auto poison = dyn_cast<ub::PoisonAttr>(value))
917  return builder.create<ub::PoisonOp>(loc, type, poison);
918 
919  if (!spirv::ConstantOp::isBuildableWith(type))
920  return nullptr;
921 
922  return builder.create<spirv::ConstantOp>(loc, type, value);
923 }
924 
925 //===----------------------------------------------------------------------===//
926 // Shader Interface ABI
927 //===----------------------------------------------------------------------===//
928 
929 LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op,
930  NamedAttribute attribute) {
931  StringRef symbol = attribute.getName().strref();
932  Attribute attr = attribute.getValue();
933 
934  if (symbol == spirv::getEntryPointABIAttrName()) {
935  if (!llvm::isa<spirv::EntryPointABIAttr>(attr)) {
936  return op->emitError("'")
937  << symbol << "' attribute must be an entry point ABI attribute";
938  }
939  } else if (symbol == spirv::getTargetEnvAttrName()) {
940  if (!llvm::isa<spirv::TargetEnvAttr>(attr))
941  return op->emitError("'") << symbol << "' must be a spirv::TargetEnvAttr";
942  } else {
943  return op->emitError("found unsupported '")
944  << symbol << "' attribute on operation";
945  }
946 
947  return success();
948 }
949 
950 /// Verifies the given SPIR-V `attribute` attached to a value of the given
951 /// `valueType` is valid.
952 static LogicalResult verifyRegionAttribute(Location loc, Type valueType,
953  NamedAttribute attribute) {
954  StringRef symbol = attribute.getName().strref();
955  Attribute attr = attribute.getValue();
956 
957  if (symbol == spirv::getInterfaceVarABIAttrName()) {
958  auto varABIAttr = llvm::dyn_cast<spirv::InterfaceVarABIAttr>(attr);
959  if (!varABIAttr)
960  return emitError(loc, "'")
961  << symbol << "' must be a spirv::InterfaceVarABIAttr";
962 
963  if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat())
964  return emitError(loc, "'") << symbol
965  << "' attribute cannot specify storage class "
966  "when attaching to a non-scalar value";
967  return success();
968  }
969  if (symbol == spirv::DecorationAttr::name) {
970  if (!isa<spirv::DecorationAttr>(attr))
971  return emitError(loc, "'")
972  << symbol << "' must be a spirv::DecorationAttr";
973  return success();
974  }
975 
976  return emitError(loc, "found unsupported '")
977  << symbol << "' attribute on region argument";
978 }
979 
980 LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
981  unsigned regionIndex,
982  unsigned argIndex,
983  NamedAttribute attribute) {
984  auto funcOp = dyn_cast<FunctionOpInterface>(op);
985  if (!funcOp)
986  return success();
987  Type argType = funcOp.getArgumentTypes()[argIndex];
988 
989  return verifyRegionAttribute(op->getLoc(), argType, attribute);
990 }
991 
992 LogicalResult SPIRVDialect::verifyRegionResultAttribute(
993  Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/,
994  NamedAttribute attribute) {
995  return op->emitError("cannot attach SPIR-V attributes to region result");
996 }
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
std::optional< unsigned > parseAndVerify< unsigned >(SPIRVDialect const &dialect, DialectAsmParser &parser)
static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect, DialectAsmParser &parser, unsigned &stride)
Parses an optional , stride = N assembly segment.
static LogicalResult verifyRegionAttribute(Location loc, Type valueType, NamedAttribute attribute)
Verifies the given SPIR-V attribute attached to a value of the given valueType is valid.
static Type parseJointMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static void print(ArrayType type, DialectAsmPrinter &os)
static Type parseSampledImageType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseAndVerifyType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static ParseResult parseStructMemberDecorations(SPIRVDialect const &dialect, DialectAsmParser &parser, ArrayRef< Type > memberTypes, SmallVectorImpl< StructType::OffsetInfo > &offsetInfo, SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorationInfo)
static Type parseAndVerifySampledImageType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseCooperativeMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
std::optional< Type > parseAndVerify< Type >(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseArrayType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static bool containsReturn(Region &region)
Returns true if the given region contains spirv.Return or spirv.ReturnValue ops.
static Type parseStructType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseRuntimeArrayType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static std::optional< IntTy > parseAndVerifyInteger(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseImageType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static std::optional< ValTy > parseAndVerify(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parsePointerType(SPIRVDialect const &dialect, DialectAsmParser &parser)
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
FailureOr< CyclicParseReset > tryStartCyclicParse(AttrOrTypeT attrOrType)
Attempts to start a cyclic parsing region for attrOrType.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseOptionalGreater()=0
Parse a '>' token if present.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Return the location of the next token.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalRSquare()=0
Parse a ] token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
FailureOr< CyclicPrintReset > tryStartCyclicPrint(AttrOrTypeT attrOrType)
Attempts to start a cyclic printing region for attrOrType.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:31
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:49
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:221
This class helps build Operations.
Definition: Builders.h:210
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:468
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:123
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:124
bool isBF16() const
Definition: Types.cpp:49
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
Type getElementType() const
Definition: SPIRVTypes.cpp:66
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:68
unsigned getNumElements() const
Definition: SPIRVTypes.cpp:64
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:52
Scope getScope() const
Returns the scope of the matrix.
Definition: SPIRVTypes.cpp:246
uint32_t getRows() const
Returns the number of rows of the matrix.
Definition: SPIRVTypes.cpp:240
uint32_t getColumns() const
Returns the number of columns of the matrix.
Definition: SPIRVTypes.cpp:242
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
Definition: SPIRVTypes.cpp:228
CooperativeMatrixUseKHR getUse() const
Returns the use parameter of the cooperative matrix.
Definition: SPIRVTypes.cpp:248
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition: SPIRVTypes.h:169
ImageDepthInfo getDepthInfo() const
Definition: SPIRVTypes.cpp:424
ImageArrayedInfo getArrayedInfo() const
Definition: SPIRVTypes.cpp:426
ImageFormat getImageFormat() const
Definition: SPIRVTypes.cpp:438
ImageSamplerUseInfo getSamplerUseInfo() const
Definition: SPIRVTypes.cpp:434
Type getElementType() const
Definition: SPIRVTypes.cpp:420
ImageSamplingInfo getSamplingInfo() const
Definition: SPIRVTypes.cpp:430
Scope getScope() const
Return the scope of the joint matrix.
Definition: SPIRVTypes.cpp:309
unsigned getColumns() const
Returns the number of columns of the matrix.
Definition: SPIRVTypes.cpp:313
static JointMatrixINTELType get(Type elementType, Scope scope, unsigned rows, unsigned columns, MatrixLayout matrixLayout)
Definition: SPIRVTypes.cpp:298
unsigned getRows() const
Returns the number of rows of the matrix.
Definition: SPIRVTypes.cpp:311
MatrixLayout getMatrixLayout() const
Returns the layout of the matrix.
Definition: SPIRVTypes.cpp:315
static MatrixType get(Type columnType, uint32_t columnCount)
Type getColumnType() const
unsigned getNumColumns() const
Returns the number of columns.
Type getPointeeType() const
Definition: SPIRVTypes.cpp:485
StorageClass getStorageClass() const
Definition: SPIRVTypes.cpp:487
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:481
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:548
static RuntimeArrayType get(Type elementType)
Definition: SPIRVTypes.cpp:538
static SampledImageType get(Type imageType)
Definition: SPIRVTypes.cpp:807
static bool isValid(FloatType)
Returns true if the given float type is valid for the SPIR-V dialect.
Definition: SPIRVTypes.cpp:582
SPIR-V struct type.
Definition: SPIRVTypes.h:293
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
unsigned getNumElements() const
Type getElementType(unsigned) const
LogicalResult trySetBody(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Sets the contents of an incomplete identified StructType.
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
uint64_t getMemberOffset(unsigned) const
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
ParseResult parseEnumKeywordAttr(EnumClass &value, ParserType &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next keyword in parser as an enumerant of the given EnumClass.
StringRef getTargetEnvAttrName()
Returns the attribute name for specifying SPIR-V target environment.
StringRef getEntryPointABIAttrName()
Returns the attribute name for specifying entry point information.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type to an MLIR context if it was valid.