MLIR  21.0.0git
SPIRVDialect.cpp
Go to the documentation of this file.
1 //===- SPIRVDialect.cpp - MLIR SPIR-V dialect -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the SPIR-V dialect in MLIR.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
15 #include "SPIRVParsingUtils.h"
16 
22 #include "mlir/IR/Builders.h"
23 #include "mlir/IR/BuiltinTypes.h"
25 #include "mlir/IR/MLIRContext.h"
26 #include "mlir/Parser/Parser.h"
28 #include "llvm/ADT/DenseMap.h"
29 #include "llvm/ADT/Sequence.h"
30 #include "llvm/ADT/SetVector.h"
31 #include "llvm/ADT/StringExtras.h"
32 #include "llvm/ADT/StringMap.h"
33 #include "llvm/ADT/TypeSwitch.h"
34 #include "llvm/Support/raw_ostream.h"
35 
36 using namespace mlir;
37 using namespace mlir::spirv;
38 
39 #include "mlir/Dialect/SPIRV/IR/SPIRVOpsDialect.cpp.inc"
40 
41 //===----------------------------------------------------------------------===//
42 // InlinerInterface
43 //===----------------------------------------------------------------------===//
44 
45 /// Returns true if the given region contains spirv.Return or spirv.ReturnValue
46 /// ops.
47 static inline bool containsReturn(Region &region) {
48  return llvm::any_of(region, [](Block &block) {
49  Operation *terminator = block.getTerminator();
50  return isa<spirv::ReturnOp, spirv::ReturnValueOp>(terminator);
51  });
52 }
53 
54 namespace {
55 /// This class defines the interface for inlining within the SPIR-V dialect.
56 struct SPIRVInlinerInterface : public DialectInlinerInterface {
58 
59  /// All call operations within SPIRV can be inlined.
60  bool isLegalToInline(Operation *call, Operation *callable,
61  bool wouldBeCloned) const final {
62  return true;
63  }
64 
65  /// Returns true if the given region 'src' can be inlined into the region
66  /// 'dest' that is attached to an operation registered to the current dialect.
67  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
68  IRMapping &) const final {
69  // Return true here when inlining into spirv.func, spirv.mlir.selection, and
70  // spirv.mlir.loop operations.
71  auto *op = dest->getParentOp();
72  return isa<spirv::FuncOp, spirv::SelectionOp, spirv::LoopOp>(op);
73  }
74 
75  /// Returns true if the given operation 'op', that is registered to this
76  /// dialect, can be inlined into the region 'dest' that is attached to an
77  /// operation registered to the current dialect.
78  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
79  IRMapping &) const final {
80  // TODO: Enable inlining structured control flows with return.
81  if ((isa<spirv::SelectionOp, spirv::LoopOp>(op)) &&
82  containsReturn(op->getRegion(0)))
83  return false;
84  // TODO: we need to filter OpKill here to avoid inlining it to
85  // a loop continue construct:
86  // https://github.com/KhronosGroup/SPIRV-Headers/issues/86
87  // For now, we just disallow inlining OpKill anywhere in the code,
88  // but this restriction should be relaxed, as pointed above.
89  if (isa<spirv::KillOp>(op))
90  return false;
91 
92  return true;
93  }
94 
95  /// Handle the given inlined terminator by replacing it with a new operation
96  /// as necessary.
97  void handleTerminator(Operation *op, Block *newDest) const final {
98  if (auto returnOp = dyn_cast<spirv::ReturnOp>(op)) {
99  OpBuilder(op).create<spirv::BranchOp>(op->getLoc(), newDest);
100  op->erase();
101  } else if (auto retValOp = dyn_cast<spirv::ReturnValueOp>(op)) {
102  OpBuilder(op).create<spirv::BranchOp>(retValOp->getLoc(), newDest,
103  retValOp->getOperands());
104  op->erase();
105  }
106  }
107 
108  /// Handle the given inlined terminator by replacing it with a new operation
109  /// as necessary.
110  void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
111  // Only spirv.ReturnValue needs to be handled here.
112  auto retValOp = dyn_cast<spirv::ReturnValueOp>(op);
113  if (!retValOp)
114  return;
115 
116  // Replace the values directly with the return operands.
117  assert(valuesToRepl.size() == 1 &&
118  "spirv.ReturnValue expected to only handle one result");
119  valuesToRepl.front().replaceAllUsesWith(retValOp.getValue());
120  }
121 };
122 } // namespace
123 
124 //===----------------------------------------------------------------------===//
125 // SPIR-V Dialect
126 //===----------------------------------------------------------------------===//
127 
128 void SPIRVDialect::initialize() {
129  registerAttributes();
130  registerTypes();
131 
132  // Add SPIR-V ops.
133  addOperations<
134 #define GET_OP_LIST
135 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"
136  >();
137 
138  addInterfaces<SPIRVInlinerInterface>();
139 
140  // Allow unknown operations because SPIR-V is extensible.
141  allowUnknownOperations();
142  declarePromisedInterface<gpu::TargetAttrInterface, TargetEnvAttr>();
143 }
144 
145 std::string SPIRVDialect::getAttributeName(Decoration decoration) {
146  return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration));
147 }
148 
149 //===----------------------------------------------------------------------===//
150 // Type Parsing
151 //===----------------------------------------------------------------------===//
152 
153 // Forward declarations.
154 template <typename ValTy>
155 static std::optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
156  DialectAsmParser &parser);
157 template <>
158 std::optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
159  DialectAsmParser &parser);
160 
161 template <>
162 std::optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
163  DialectAsmParser &parser);
164 
165 static Type parseAndVerifyType(SPIRVDialect const &dialect,
166  DialectAsmParser &parser) {
167  Type type;
168  SMLoc typeLoc = parser.getCurrentLocation();
169  if (parser.parseType(type))
170  return Type();
171 
172  // Allow SPIR-V dialect types
173  if (&type.getDialect() == &dialect)
174  return type;
175 
176  // Check other allowed types
177  if (auto t = llvm::dyn_cast<FloatType>(type)) {
178  if (type.isBF16()) {
179  parser.emitError(typeLoc, "cannot use 'bf16' to compose SPIR-V types");
180  return Type();
181  }
182  } else if (auto t = llvm::dyn_cast<IntegerType>(type)) {
183  if (!ScalarType::isValid(t)) {
184  parser.emitError(typeLoc,
185  "only 1/8/16/32/64-bit integer type allowed but found ")
186  << type;
187  return Type();
188  }
189  } else if (auto t = llvm::dyn_cast<VectorType>(type)) {
190  if (t.getRank() != 1) {
191  parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
192  return Type();
193  }
194  if (t.getNumElements() > 4) {
195  parser.emitError(
196  typeLoc, "vector length has to be less than or equal to 4 but found ")
197  << t.getNumElements();
198  return Type();
199  }
200  } else {
201  parser.emitError(typeLoc, "cannot use ")
202  << type << " to compose SPIR-V types";
203  return Type();
204  }
205 
206  return type;
207 }
208 
209 static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect,
210  DialectAsmParser &parser) {
211  Type type;
212  SMLoc typeLoc = parser.getCurrentLocation();
213  if (parser.parseType(type))
214  return Type();
215 
216  if (auto t = llvm::dyn_cast<VectorType>(type)) {
217  if (t.getRank() != 1) {
218  parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
219  return Type();
220  }
221  if (t.getNumElements() > 4 || t.getNumElements() < 2) {
222  parser.emitError(typeLoc,
223  "matrix columns size has to be less than or equal "
224  "to 4 and greater than or equal 2, but found ")
225  << t.getNumElements();
226  return Type();
227  }
228 
229  if (!llvm::isa<FloatType>(t.getElementType())) {
230  parser.emitError(typeLoc, "matrix columns' elements must be of "
231  "Float type, got ")
232  << t.getElementType();
233  return Type();
234  }
235  } else {
236  parser.emitError(typeLoc, "matrix must be composed using vector "
237  "type, got ")
238  << type;
239  return Type();
240  }
241 
242  return type;
243 }
244 
245 static Type parseAndVerifySampledImageType(SPIRVDialect const &dialect,
246  DialectAsmParser &parser) {
247  Type type;
248  SMLoc typeLoc = parser.getCurrentLocation();
249  if (parser.parseType(type))
250  return Type();
251 
252  if (!llvm::isa<ImageType>(type)) {
253  parser.emitError(typeLoc,
254  "sampled image must be composed using image type, got ")
255  << type;
256  return Type();
257  }
258 
259  return type;
260 }
261 
262 /// Parses an optional `, stride = N` assembly segment. If no parsing failure
263 /// occurs, writes `N` to `stride` if existing and writes 0 to `stride` if
264 /// missing.
265 static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect,
266  DialectAsmParser &parser,
267  unsigned &stride) {
268  if (failed(parser.parseOptionalComma())) {
269  stride = 0;
270  return success();
271  }
272 
273  if (parser.parseKeyword("stride") || parser.parseEqual())
274  return failure();
275 
276  SMLoc strideLoc = parser.getCurrentLocation();
277  std::optional<unsigned> optStride = parseAndVerify<unsigned>(dialect, parser);
278  if (!optStride)
279  return failure();
280 
281  if (!(stride = *optStride)) {
282  parser.emitError(strideLoc, "ArrayStride must be greater than zero");
283  return failure();
284  }
285  return success();
286 }
287 
288 // element-type ::= integer-type
289 // | floating-point-type
290 // | vector-type
291 // | spirv-type
292 //
293 // array-type ::= `!spirv.array` `<` integer-literal `x` element-type
294 // (`,` `stride` `=` integer-literal)? `>`
295 static Type parseArrayType(SPIRVDialect const &dialect,
296  DialectAsmParser &parser) {
297  if (parser.parseLess())
298  return Type();
299 
300  SmallVector<int64_t, 1> countDims;
301  SMLoc countLoc = parser.getCurrentLocation();
302  if (parser.parseDimensionList(countDims, /*allowDynamic=*/false))
303  return Type();
304  if (countDims.size() != 1) {
305  parser.emitError(countLoc,
306  "expected single integer for array element count");
307  return Type();
308  }
309 
310  // According to the SPIR-V spec:
311  // "Length is the number of elements in the array. It must be at least 1."
312  int64_t count = countDims[0];
313  if (count == 0) {
314  parser.emitError(countLoc, "expected array length greater than 0");
315  return Type();
316  }
317 
318  Type elementType = parseAndVerifyType(dialect, parser);
319  if (!elementType)
320  return Type();
321 
322  unsigned stride = 0;
323  if (failed(parseOptionalArrayStride(dialect, parser, stride)))
324  return Type();
325 
326  if (parser.parseGreater())
327  return Type();
328  return ArrayType::get(elementType, count, stride);
329 }
330 
331 // cooperative-matrix-type ::=
332 // `!spirv.coopmatrix` `<` rows `x` columns `x` element-type `,`
333 // scope `,` use `>`
334 static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
335  DialectAsmParser &parser) {
336  if (parser.parseLess())
337  return {};
338 
340  SMLoc countLoc = parser.getCurrentLocation();
341  if (parser.parseDimensionList(dims, /*allowDynamic=*/false))
342  return {};
343 
344  if (dims.size() != 2) {
345  parser.emitError(countLoc, "expected row and column count");
346  return {};
347  }
348 
349  auto elementTy = parseAndVerifyType(dialect, parser);
350  if (!elementTy)
351  return {};
352 
353  Scope scope;
354  if (parser.parseComma() ||
355  spirv::parseEnumKeywordAttr(scope, parser, "scope <id>"))
356  return {};
357 
358  CooperativeMatrixUseKHR use;
359  if (parser.parseComma() ||
360  spirv::parseEnumKeywordAttr(use, parser, "use <id>"))
361  return {};
362 
363  if (parser.parseGreater())
364  return {};
365 
366  return CooperativeMatrixType::get(elementTy, dims[0], dims[1], scope, use);
367 }
368 
369 // TODO: Reorder methods to be utilities first and parse*Type
370 // methods in alphabetical order
371 //
372 // storage-class ::= `UniformConstant`
373 // | `Uniform`
374 // | `Workgroup`
375 // | <and other storage classes...>
376 //
377 // pointer-type ::= `!spirv.ptr<` element-type `,` storage-class `>`
378 static Type parsePointerType(SPIRVDialect const &dialect,
379  DialectAsmParser &parser) {
380  if (parser.parseLess())
381  return Type();
382 
383  auto pointeeType = parseAndVerifyType(dialect, parser);
384  if (!pointeeType)
385  return Type();
386 
387  StringRef storageClassSpec;
388  SMLoc storageClassLoc = parser.getCurrentLocation();
389  if (parser.parseComma() || parser.parseKeyword(&storageClassSpec))
390  return Type();
391 
392  auto storageClass = symbolizeStorageClass(storageClassSpec);
393  if (!storageClass) {
394  parser.emitError(storageClassLoc, "unknown storage class: ")
395  << storageClassSpec;
396  return Type();
397  }
398  if (parser.parseGreater())
399  return Type();
400  return PointerType::get(pointeeType, *storageClass);
401 }
402 
403 // runtime-array-type ::= `!spirv.rtarray` `<` element-type
404 // (`,` `stride` `=` integer-literal)? `>`
405 static Type parseRuntimeArrayType(SPIRVDialect const &dialect,
406  DialectAsmParser &parser) {
407  if (parser.parseLess())
408  return Type();
409 
410  Type elementType = parseAndVerifyType(dialect, parser);
411  if (!elementType)
412  return Type();
413 
414  unsigned stride = 0;
415  if (failed(parseOptionalArrayStride(dialect, parser, stride)))
416  return Type();
417 
418  if (parser.parseGreater())
419  return Type();
420  return RuntimeArrayType::get(elementType, stride);
421 }
422 
423 // matrix-type ::= `!spirv.matrix` `<` integer-literal `x` element-type `>`
424 static Type parseMatrixType(SPIRVDialect const &dialect,
425  DialectAsmParser &parser) {
426  if (parser.parseLess())
427  return Type();
428 
429  SmallVector<int64_t, 1> countDims;
430  SMLoc countLoc = parser.getCurrentLocation();
431  if (parser.parseDimensionList(countDims, /*allowDynamic=*/false))
432  return Type();
433  if (countDims.size() != 1) {
434  parser.emitError(countLoc, "expected single unsigned "
435  "integer for number of columns");
436  return Type();
437  }
438 
439  int64_t columnCount = countDims[0];
440  // According to the specification, Matrices can have 2, 3, or 4 columns
441  if (columnCount < 2 || columnCount > 4) {
442  parser.emitError(countLoc, "matrix is expected to have 2, 3, or 4 "
443  "columns");
444  return Type();
445  }
446 
447  Type columnType = parseAndVerifyMatrixType(dialect, parser);
448  if (!columnType)
449  return Type();
450 
451  if (parser.parseGreater())
452  return Type();
453 
454  return MatrixType::get(columnType, columnCount);
455 }
456 
457 // Specialize this function to parse each of the parameters that define an
458 // ImageType. By default it assumes this is an enum type.
459 template <typename ValTy>
460 static std::optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
461  DialectAsmParser &parser) {
462  StringRef enumSpec;
463  SMLoc enumLoc = parser.getCurrentLocation();
464  if (parser.parseKeyword(&enumSpec)) {
465  return std::nullopt;
466  }
467 
468  auto val = spirv::symbolizeEnum<ValTy>(enumSpec);
469  if (!val)
470  parser.emitError(enumLoc, "unknown attribute: '") << enumSpec << "'";
471  return val;
472 }
473 
474 template <>
475 std::optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
476  DialectAsmParser &parser) {
477  // TODO: Further verify that the element type can be sampled
478  auto ty = parseAndVerifyType(dialect, parser);
479  if (!ty)
480  return std::nullopt;
481  return ty;
482 }
483 
484 template <typename IntTy>
485 static std::optional<IntTy> parseAndVerifyInteger(SPIRVDialect const &dialect,
486  DialectAsmParser &parser) {
487  IntTy offsetVal = std::numeric_limits<IntTy>::max();
488  if (parser.parseInteger(offsetVal))
489  return std::nullopt;
490  return offsetVal;
491 }
492 
493 template <>
494 std::optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
495  DialectAsmParser &parser) {
496  return parseAndVerifyInteger<unsigned>(dialect, parser);
497 }
498 
499 namespace {
500 // Functor object to parse a comma separated list of specs. The function
501 // parseAndVerify does the actual parsing and verification of individual
502 // elements. This is a functor since parsing the last element of the list
503 // (termination condition) needs partial specialization.
504 template <typename ParseType, typename... Args>
505 struct ParseCommaSeparatedList {
506  std::optional<std::tuple<ParseType, Args...>>
507  operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const {
508  auto parseVal = parseAndVerify<ParseType>(dialect, parser);
509  if (!parseVal)
510  return std::nullopt;
511 
512  auto numArgs = std::tuple_size<std::tuple<Args...>>::value;
513  if (numArgs != 0 && failed(parser.parseComma()))
514  return std::nullopt;
515  auto remainingValues = ParseCommaSeparatedList<Args...>{}(dialect, parser);
516  if (!remainingValues)
517  return std::nullopt;
518  return std::tuple_cat(std::tuple<ParseType>(parseVal.value()),
519  remainingValues.value());
520  }
521 };
522 
523 // Partial specialization of the function to parse a comma separated list of
524 // specs to parse the last element of the list.
525 template <typename ParseType>
526 struct ParseCommaSeparatedList<ParseType> {
527  std::optional<std::tuple<ParseType>>
528  operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const {
529  if (auto value = parseAndVerify<ParseType>(dialect, parser))
530  return std::tuple<ParseType>(*value);
531  return std::nullopt;
532  }
533 };
534 } // namespace
535 
536 // dim ::= `1D` | `2D` | `3D` | `Cube` | <and other SPIR-V Dim specifiers...>
537 //
538 // depth-info ::= `NoDepth` | `IsDepth` | `DepthUnknown`
539 //
540 // arrayed-info ::= `NonArrayed` | `Arrayed`
541 //
542 // sampling-info ::= `SingleSampled` | `MultiSampled`
543 //
544 // sampler-use-info ::= `SamplerUnknown` | `NeedSampler` | `NoSampler`
545 //
546 // format ::= `Unknown` | `Rgba32f` | <and other SPIR-V Image formats...>
547 //
548 // image-type ::= `!spirv.image<` element-type `,` dim `,` depth-info `,`
549 // arrayed-info `,` sampling-info `,`
550 // sampler-use-info `,` format `>`
551 static Type parseImageType(SPIRVDialect const &dialect,
552  DialectAsmParser &parser) {
553  if (parser.parseLess())
554  return Type();
555 
556  auto value =
557  ParseCommaSeparatedList<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
558  ImageSamplingInfo, ImageSamplerUseInfo,
559  ImageFormat>{}(dialect, parser);
560  if (!value)
561  return Type();
562 
563  if (parser.parseGreater())
564  return Type();
565  return ImageType::get(*value);
566 }
567 
568 // sampledImage-type :: = `!spirv.sampledImage<` image-type `>`
569 static Type parseSampledImageType(SPIRVDialect const &dialect,
570  DialectAsmParser &parser) {
571  if (parser.parseLess())
572  return Type();
573 
574  Type parsedType = parseAndVerifySampledImageType(dialect, parser);
575  if (!parsedType)
576  return Type();
577 
578  if (parser.parseGreater())
579  return Type();
580  return SampledImageType::get(parsedType);
581 }
582 
583 // Parse decorations associated with a member.
584 static ParseResult parseStructMemberDecorations(
585  SPIRVDialect const &dialect, DialectAsmParser &parser,
586  ArrayRef<Type> memberTypes,
589 
590  // Check if the first element is offset.
591  SMLoc offsetLoc = parser.getCurrentLocation();
592  StructType::OffsetInfo offset = 0;
593  OptionalParseResult offsetParseResult = parser.parseOptionalInteger(offset);
594  if (offsetParseResult.has_value()) {
595  if (failed(*offsetParseResult))
596  return failure();
597 
598  if (offsetInfo.size() != memberTypes.size() - 1) {
599  return parser.emitError(offsetLoc,
600  "offset specification must be given for "
601  "all members");
602  }
603  offsetInfo.push_back(offset);
604  }
605 
606  // Check for no spirv::Decorations.
607  if (succeeded(parser.parseOptionalRSquare()))
608  return success();
609 
610  // If there was an offset, make sure to parse the comma.
611  if (offsetParseResult.has_value() && parser.parseComma())
612  return failure();
613 
614  // Check for spirv::Decorations.
615  auto parseDecorations = [&]() {
616  auto memberDecoration = parseAndVerify<spirv::Decoration>(dialect, parser);
617  if (!memberDecoration)
618  return failure();
619 
620  // Parse member decoration value if it exists.
621  if (succeeded(parser.parseOptionalEqual())) {
622  auto memberDecorationValue =
623  parseAndVerifyInteger<uint32_t>(dialect, parser);
624 
625  if (!memberDecorationValue)
626  return failure();
627 
628  memberDecorationInfo.emplace_back(
629  static_cast<uint32_t>(memberTypes.size() - 1), 1,
630  memberDecoration.value(), memberDecorationValue.value());
631  } else {
632  memberDecorationInfo.emplace_back(
633  static_cast<uint32_t>(memberTypes.size() - 1), 0,
634  memberDecoration.value(), 0);
635  }
636  return success();
637  };
638  if (failed(parser.parseCommaSeparatedList(parseDecorations)) ||
639  failed(parser.parseRSquare()))
640  return failure();
641 
642  return success();
643 }
644 
645 // struct-member-decoration ::= integer-literal? spirv-decoration*
646 // struct-type ::=
647 // `!spirv.struct<` (id `,`)?
648 // `(`
649 // (spirv-type (`[` struct-member-decoration `]`)?)*
650 // `)>`
651 static Type parseStructType(SPIRVDialect const &dialect,
652  DialectAsmParser &parser) {
653  // TODO: This function is quite lengthy. Break it down into smaller chunks.
654 
655  if (parser.parseLess())
656  return Type();
657 
658  StringRef identifier;
659  FailureOr<DialectAsmParser::CyclicParseReset> cyclicParse;
660 
661  // Check if this is an identified struct type.
662  if (succeeded(parser.parseOptionalKeyword(&identifier))) {
663  // Check if this is a possible recursive reference.
664  auto structType =
665  StructType::getIdentified(dialect.getContext(), identifier);
666  cyclicParse = parser.tryStartCyclicParse(structType);
667  if (succeeded(parser.parseOptionalGreater())) {
668  if (succeeded(cyclicParse)) {
669  parser.emitError(
670  parser.getNameLoc(),
671  "recursive struct reference not nested in struct definition");
672 
673  return Type();
674  }
675 
676  return structType;
677  }
678 
679  if (failed(parser.parseComma()))
680  return Type();
681 
682  if (failed(cyclicParse)) {
683  parser.emitError(parser.getNameLoc(),
684  "identifier already used for an enclosing struct");
685  return Type();
686  }
687  }
688 
689  if (failed(parser.parseLParen()))
690  return Type();
691 
692  if (succeeded(parser.parseOptionalRParen()) &&
693  succeeded(parser.parseOptionalGreater())) {
694  return StructType::getEmpty(dialect.getContext(), identifier);
695  }
696 
697  StructType idStructTy;
698 
699  if (!identifier.empty())
700  idStructTy = StructType::getIdentified(dialect.getContext(), identifier);
701 
702  SmallVector<Type, 4> memberTypes;
705 
706  do {
707  Type memberType;
708  if (parser.parseType(memberType))
709  return Type();
710  memberTypes.push_back(memberType);
711 
712  if (succeeded(parser.parseOptionalLSquare()))
713  if (parseStructMemberDecorations(dialect, parser, memberTypes, offsetInfo,
714  memberDecorationInfo))
715  return Type();
716  } while (succeeded(parser.parseOptionalComma()));
717 
718  if (!offsetInfo.empty() && memberTypes.size() != offsetInfo.size()) {
719  parser.emitError(parser.getNameLoc(),
720  "offset specification must be given for all members");
721  return Type();
722  }
723 
724  if (failed(parser.parseRParen()) || failed(parser.parseGreater()))
725  return Type();
726 
727  if (!identifier.empty()) {
728  if (failed(idStructTy.trySetBody(memberTypes, offsetInfo,
729  memberDecorationInfo)))
730  return Type();
731  return idStructTy;
732  }
733 
734  return StructType::get(memberTypes, offsetInfo, memberDecorationInfo);
735 }
736 
737 // spirv-type ::= array-type
738 // | element-type
739 // | image-type
740 // | pointer-type
741 // | runtime-array-type
742 // | sampled-image-type
743 // | struct-type
745  StringRef keyword;
746  if (parser.parseKeyword(&keyword))
747  return Type();
748 
749  if (keyword == "array")
750  return parseArrayType(*this, parser);
751  if (keyword == "coopmatrix")
752  return parseCooperativeMatrixType(*this, parser);
753  if (keyword == "image")
754  return parseImageType(*this, parser);
755  if (keyword == "ptr")
756  return parsePointerType(*this, parser);
757  if (keyword == "rtarray")
758  return parseRuntimeArrayType(*this, parser);
759  if (keyword == "sampled_image")
760  return parseSampledImageType(*this, parser);
761  if (keyword == "struct")
762  return parseStructType(*this, parser);
763  if (keyword == "matrix")
764  return parseMatrixType(*this, parser);
765  parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword;
766  return Type();
767 }
768 
769 //===----------------------------------------------------------------------===//
770 // Type Printing
771 //===----------------------------------------------------------------------===//
772 
773 static void print(ArrayType type, DialectAsmPrinter &os) {
774  os << "array<" << type.getNumElements() << " x " << type.getElementType();
775  if (unsigned stride = type.getArrayStride())
776  os << ", stride=" << stride;
777  os << ">";
778 }
779 
780 static void print(RuntimeArrayType type, DialectAsmPrinter &os) {
781  os << "rtarray<" << type.getElementType();
782  if (unsigned stride = type.getArrayStride())
783  os << ", stride=" << stride;
784  os << ">";
785 }
786 
787 static void print(PointerType type, DialectAsmPrinter &os) {
788  os << "ptr<" << type.getPointeeType() << ", "
789  << stringifyStorageClass(type.getStorageClass()) << ">";
790 }
791 
792 static void print(ImageType type, DialectAsmPrinter &os) {
793  os << "image<" << type.getElementType() << ", " << stringifyDim(type.getDim())
794  << ", " << stringifyImageDepthInfo(type.getDepthInfo()) << ", "
795  << stringifyImageArrayedInfo(type.getArrayedInfo()) << ", "
796  << stringifyImageSamplingInfo(type.getSamplingInfo()) << ", "
797  << stringifyImageSamplerUseInfo(type.getSamplerUseInfo()) << ", "
798  << stringifyImageFormat(type.getImageFormat()) << ">";
799 }
800 
801 static void print(SampledImageType type, DialectAsmPrinter &os) {
802  os << "sampled_image<" << type.getImageType() << ">";
803 }
804 
805 static void print(StructType type, DialectAsmPrinter &os) {
806  FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint;
807 
808  os << "struct<";
809 
810  if (type.isIdentified()) {
811  os << type.getIdentifier();
812 
813  cyclicPrint = os.tryStartCyclicPrint(type);
814  if (failed(cyclicPrint)) {
815  os << ">";
816  return;
817  }
818 
819  os << ", ";
820  }
821 
822  os << "(";
823 
824  auto printMember = [&](unsigned i) {
825  os << type.getElementType(i);
827  type.getMemberDecorations(i, decorations);
828  if (type.hasOffset() || !decorations.empty()) {
829  os << " [";
830  if (type.hasOffset()) {
831  os << type.getMemberOffset(i);
832  if (!decorations.empty())
833  os << ", ";
834  }
835  auto eachFn = [&os](spirv::StructType::MemberDecorationInfo decoration) {
836  os << stringifyDecoration(decoration.decoration);
837  if (decoration.hasValue) {
838  os << "=" << decoration.decorationValue;
839  }
840  };
841  llvm::interleaveComma(decorations, os, eachFn);
842  os << "]";
843  }
844  };
845  llvm::interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os,
846  printMember);
847  os << ")>";
848 }
849 
851  os << "coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x"
852  << type.getElementType() << ", " << type.getScope() << ", "
853  << type.getUse() << ">";
854 }
855 
856 static void print(MatrixType type, DialectAsmPrinter &os) {
857  os << "matrix<" << type.getNumColumns() << " x " << type.getColumnType();
858  os << ">";
859 }
860 
861 void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
862  TypeSwitch<Type>(type)
865  [&](auto type) { print(type, os); })
866  .Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
867 }
868 
869 //===----------------------------------------------------------------------===//
870 // Constant
871 //===----------------------------------------------------------------------===//
872 
874  Attribute value, Type type,
875  Location loc) {
876  if (auto poison = dyn_cast<ub::PoisonAttr>(value))
877  return builder.create<ub::PoisonOp>(loc, type, poison);
878 
879  if (!spirv::ConstantOp::isBuildableWith(type))
880  return nullptr;
881 
882  return builder.create<spirv::ConstantOp>(loc, type, value);
883 }
884 
885 //===----------------------------------------------------------------------===//
886 // Shader Interface ABI
887 //===----------------------------------------------------------------------===//
888 
889 LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op,
890  NamedAttribute attribute) {
891  StringRef symbol = attribute.getName().strref();
892  Attribute attr = attribute.getValue();
893 
894  if (symbol == spirv::getEntryPointABIAttrName()) {
895  if (!llvm::isa<spirv::EntryPointABIAttr>(attr)) {
896  return op->emitError("'")
897  << symbol << "' attribute must be an entry point ABI attribute";
898  }
899  } else if (symbol == spirv::getTargetEnvAttrName()) {
900  if (!llvm::isa<spirv::TargetEnvAttr>(attr))
901  return op->emitError("'") << symbol << "' must be a spirv::TargetEnvAttr";
902  } else {
903  return op->emitError("found unsupported '")
904  << symbol << "' attribute on operation";
905  }
906 
907  return success();
908 }
909 
910 /// Verifies the given SPIR-V `attribute` attached to a value of the given
911 /// `valueType` is valid.
912 static LogicalResult verifyRegionAttribute(Location loc, Type valueType,
913  NamedAttribute attribute) {
914  StringRef symbol = attribute.getName().strref();
915  Attribute attr = attribute.getValue();
916 
917  if (symbol == spirv::getInterfaceVarABIAttrName()) {
918  auto varABIAttr = llvm::dyn_cast<spirv::InterfaceVarABIAttr>(attr);
919  if (!varABIAttr)
920  return emitError(loc, "'")
921  << symbol << "' must be a spirv::InterfaceVarABIAttr";
922 
923  if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat())
924  return emitError(loc, "'") << symbol
925  << "' attribute cannot specify storage class "
926  "when attaching to a non-scalar value";
927  return success();
928  }
929  if (symbol == spirv::DecorationAttr::name) {
930  if (!isa<spirv::DecorationAttr>(attr))
931  return emitError(loc, "'")
932  << symbol << "' must be a spirv::DecorationAttr";
933  return success();
934  }
935 
936  return emitError(loc, "found unsupported '")
937  << symbol << "' attribute on region argument";
938 }
939 
940 LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
941  unsigned regionIndex,
942  unsigned argIndex,
943  NamedAttribute attribute) {
944  auto funcOp = dyn_cast<FunctionOpInterface>(op);
945  if (!funcOp)
946  return success();
947  Type argType = funcOp.getArgumentTypes()[argIndex];
948 
949  return verifyRegionAttribute(op->getLoc(), argType, attribute);
950 }
951 
952 LogicalResult SPIRVDialect::verifyRegionResultAttribute(
953  Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/,
954  NamedAttribute attribute) {
955  return op->emitError("cannot attach SPIR-V attributes to region result");
956 }
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
std::optional< unsigned > parseAndVerify< unsigned >(SPIRVDialect const &dialect, DialectAsmParser &parser)
static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect, DialectAsmParser &parser, unsigned &stride)
Parses an optional `, stride = N` assembly segment.
static LogicalResult verifyRegionAttribute(Location loc, Type valueType, NamedAttribute attribute)
Verifies the given SPIR-V attribute attached to a value of the given valueType is valid.
static void print(ArrayType type, DialectAsmPrinter &os)
static Type parseSampledImageType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseAndVerifyType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static ParseResult parseStructMemberDecorations(SPIRVDialect const &dialect, DialectAsmParser &parser, ArrayRef< Type > memberTypes, SmallVectorImpl< StructType::OffsetInfo > &offsetInfo, SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorationInfo)
static Type parseAndVerifySampledImageType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseCooperativeMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
std::optional< Type > parseAndVerify< Type >(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseArrayType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static bool containsReturn(Region &region)
Returns true if the given region contains spirv.Return or spirv.ReturnValue ops.
static Type parseStructType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseRuntimeArrayType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static std::optional< IntTy > parseAndVerifyInteger(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseImageType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static std::optional< ValTy > parseAndVerify(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parsePointerType(SPIRVDialect const &dialect, DialectAsmParser &parser)
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
FailureOr< CyclicParseReset > tryStartCyclicParse(AttrOrTypeT attrOrType)
Attempts to start a cyclic parsing region for attrOrType.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseOptionalGreater()=0
Parse a '>' token if present.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalRSquare()=0
Parse a ] token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
FailureOr< CyclicPrintReset > tryStartCyclicPrint(AttrOrTypeT attrOrType)
Attempts to start a cyclic printing region for attrOrType.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:55
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:179
This class helps build Operations.
Definition: Builders.h:205
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:107
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:120
bool isBF16() const
Definition: Types.cpp:37
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
Type getElementType() const
Definition: SPIRVTypes.cpp:66
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:68
unsigned getNumElements() const
Definition: SPIRVTypes.cpp:64
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:52
Scope getScope() const
Returns the scope of the matrix.
Definition: SPIRVTypes.cpp:240
uint32_t getRows() const
Returns the number of rows of the matrix.
Definition: SPIRVTypes.cpp:234
uint32_t getColumns() const
Returns the number of columns of the matrix.
Definition: SPIRVTypes.cpp:236
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
Definition: SPIRVTypes.cpp:222
CooperativeMatrixUseKHR getUse() const
Returns the use parameter of the cooperative matrix.
Definition: SPIRVTypes.cpp:242
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition: SPIRVTypes.h:168
ImageDepthInfo getDepthInfo() const
Definition: SPIRVTypes.cpp:349
ImageArrayedInfo getArrayedInfo() const
Definition: SPIRVTypes.cpp:351
ImageFormat getImageFormat() const
Definition: SPIRVTypes.cpp:363
ImageSamplerUseInfo getSamplerUseInfo() const
Definition: SPIRVTypes.cpp:359
Type getElementType() const
Definition: SPIRVTypes.cpp:345
ImageSamplingInfo getSamplingInfo() const
Definition: SPIRVTypes.cpp:355
static MatrixType get(Type columnType, uint32_t columnCount)
Type getColumnType() const
unsigned getNumColumns() const
Returns the number of columns.
Type getPointeeType() const
Definition: SPIRVTypes.cpp:410
StorageClass getStorageClass() const
Definition: SPIRVTypes.cpp:412
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:406
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:473
static RuntimeArrayType get(Type elementType)
Definition: SPIRVTypes.cpp:463
static SampledImageType get(Type imageType)
Definition: SPIRVTypes.cpp:732
static bool isValid(FloatType)
Returns true if the given float type is valid for the SPIR-V dialect.
Definition: SPIRVTypes.cpp:507
SPIR-V struct type.
Definition: SPIRVTypes.h:293
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
Definition: SPIRVTypes.cpp:974
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
Definition: SPIRVTypes.cpp:998
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
Definition: SPIRVTypes.cpp:984
unsigned getNumElements() const
Type getElementType(unsigned) const
LogicalResult trySetBody(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Sets the contents of an incomplete identified StructType.
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
Definition: SPIRVTypes.cpp:961
uint64_t getMemberOffset(unsigned) const
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
ParseResult parseEnumKeywordAttr(EnumClass &value, ParserType &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next keyword in parser as an enumerant of the given EnumClass.
StringRef getTargetEnvAttrName()
Returns the attribute name for specifying SPIR-V target environment.
StringRef getEntryPointABIAttrName()
Returns the attribute name for specifying entry point information.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type to an MLIR context if it was valid.