MLIR  16.0.0git
SPIRVDialect.cpp
Go to the documentation of this file.
1 //===- SPIRVDialect.cpp - MLIR SPIR-V dialect -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the SPIR-V dialect in MLIR.
10 //
11 //===----------------------------------------------------------------------===//
12 
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/MLIRContext.h"
22 #include "mlir/Parser/Parser.h"
24 #include "llvm/ADT/DenseMap.h"
25 #include "llvm/ADT/Sequence.h"
26 #include "llvm/ADT/SetVector.h"
27 #include "llvm/ADT/StringExtras.h"
28 #include "llvm/ADT/StringMap.h"
29 #include "llvm/ADT/TypeSwitch.h"
30 #include "llvm/Support/raw_ostream.h"
31 
32 using namespace mlir;
33 using namespace mlir::spirv;
34 
35 #include "mlir/Dialect/SPIRV/IR/SPIRVOpsDialect.cpp.inc"
36 
37 //===----------------------------------------------------------------------===//
38 // InlinerInterface
39 //===----------------------------------------------------------------------===//
40 
41 /// Returns true if the given region contains spirv.Return or spirv.ReturnValue
42 /// ops.
43 static inline bool containsReturn(Region &region) {
44  return llvm::any_of(region, [](Block &block) {
45  Operation *terminator = block.getTerminator();
46  return isa<spirv::ReturnOp, spirv::ReturnValueOp>(terminator);
47  });
48 }
49 
50 namespace {
51 /// This class defines the interface for inlining within the SPIR-V dialect.
52 struct SPIRVInlinerInterface : public DialectInlinerInterface {
54 
55  /// All call operations within SPIRV can be inlined.
56  bool isLegalToInline(Operation *call, Operation *callable,
57  bool wouldBeCloned) const final {
58  return true;
59  }
60 
61  /// Returns true if the given region 'src' can be inlined into the region
62  /// 'dest' that is attached to an operation registered to the current dialect.
63  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
64  BlockAndValueMapping &) const final {
65  // Return true here when inlining into spirv.func, spirv.mlir.selection, and
66  // spirv.mlir.loop operations.
67  auto *op = dest->getParentOp();
68  return isa<spirv::FuncOp, spirv::SelectionOp, spirv::LoopOp>(op);
69  }
70 
71  /// Returns true if the given operation 'op', that is registered to this
72  /// dialect, can be inlined into the region 'dest' that is attached to an
73  /// operation registered to the current dialect.
74  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
75  BlockAndValueMapping &) const final {
76  // TODO: Enable inlining structured control flows with return.
77  if ((isa<spirv::SelectionOp, spirv::LoopOp>(op)) &&
78  containsReturn(op->getRegion(0)))
79  return false;
80  // TODO: we need to filter OpKill here to avoid inlining it to
81  // a loop continue construct:
82  // https://github.com/KhronosGroup/SPIRV-Headers/issues/86
83  // However OpKill is fragment shader specific and we don't support it yet.
84  return true;
85  }
86 
87  /// Handle the given inlined terminator by replacing it with a new operation
88  /// as necessary.
89  void handleTerminator(Operation *op, Block *newDest) const final {
90  if (auto returnOp = dyn_cast<spirv::ReturnOp>(op)) {
91  OpBuilder(op).create<spirv::BranchOp>(op->getLoc(), newDest);
92  op->erase();
93  } else if (auto retValOp = dyn_cast<spirv::ReturnValueOp>(op)) {
94  llvm_unreachable("unimplemented spirv.ReturnValue in inliner");
95  }
96  }
97 
98  /// Handle the given inlined terminator by replacing it with a new operation
99  /// as necessary.
100  void handleTerminator(Operation *op,
101  ArrayRef<Value> valuesToRepl) const final {
102  // Only spirv.ReturnValue needs to be handled here.
103  auto retValOp = dyn_cast<spirv::ReturnValueOp>(op);
104  if (!retValOp)
105  return;
106 
107  // Replace the values directly with the return operands.
108  assert(valuesToRepl.size() == 1 &&
109  "spirv.ReturnValue expected to only handle one result");
110  valuesToRepl.front().replaceAllUsesWith(retValOp.getValue());
111  }
112 };
113 } // namespace
114 
115 //===----------------------------------------------------------------------===//
116 // SPIR-V Dialect
117 //===----------------------------------------------------------------------===//
118 
119 void SPIRVDialect::initialize() {
120  registerAttributes();
121  registerTypes();
122 
123  // Add SPIR-V ops.
124  addOperations<
125 #define GET_OP_LIST
126 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"
127  >();
128 
129  addInterfaces<SPIRVInlinerInterface>();
130 
131  // Allow unknown operations because SPIR-V is extensible.
132  allowUnknownOperations();
133 }
134 
135 std::string SPIRVDialect::getAttributeName(Decoration decoration) {
136  return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration));
137 }
138 
139 //===----------------------------------------------------------------------===//
140 // Type Parsing
141 //===----------------------------------------------------------------------===//
142 
143 // Forward declarations.
144 template <typename ValTy>
145 static Optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
146  DialectAsmParser &parser);
147 template <>
148 Optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
149  DialectAsmParser &parser);
150 
151 template <>
152 Optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
153  DialectAsmParser &parser);
154 
155 static Type parseAndVerifyType(SPIRVDialect const &dialect,
156  DialectAsmParser &parser) {
157  Type type;
158  SMLoc typeLoc = parser.getCurrentLocation();
159  if (parser.parseType(type))
160  return Type();
161 
162  // Allow SPIR-V dialect types
163  if (&type.getDialect() == &dialect)
164  return type;
165 
166  // Check other allowed types
167  if (auto t = type.dyn_cast<FloatType>()) {
168  if (type.isBF16()) {
169  parser.emitError(typeLoc, "cannot use 'bf16' to compose SPIR-V types");
170  return Type();
171  }
172  } else if (auto t = type.dyn_cast<IntegerType>()) {
173  if (!ScalarType::isValid(t)) {
174  parser.emitError(typeLoc,
175  "only 1/8/16/32/64-bit integer type allowed but found ")
176  << type;
177  return Type();
178  }
179  } else if (auto t = type.dyn_cast<VectorType>()) {
180  if (t.getRank() != 1) {
181  parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
182  return Type();
183  }
184  if (t.getNumElements() > 4) {
185  parser.emitError(
186  typeLoc, "vector length has to be less than or equal to 4 but found ")
187  << t.getNumElements();
188  return Type();
189  }
190  } else {
191  parser.emitError(typeLoc, "cannot use ")
192  << type << " to compose SPIR-V types";
193  return Type();
194  }
195 
196  return type;
197 }
198 
199 static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect,
200  DialectAsmParser &parser) {
201  Type type;
202  SMLoc typeLoc = parser.getCurrentLocation();
203  if (parser.parseType(type))
204  return Type();
205 
206  if (auto t = type.dyn_cast<VectorType>()) {
207  if (t.getRank() != 1) {
208  parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
209  return Type();
210  }
211  if (t.getNumElements() > 4 || t.getNumElements() < 2) {
212  parser.emitError(typeLoc,
213  "matrix columns size has to be less than or equal "
214  "to 4 and greater than or equal 2, but found ")
215  << t.getNumElements();
216  return Type();
217  }
218 
219  if (!t.getElementType().isa<FloatType>()) {
220  parser.emitError(typeLoc, "matrix columns' elements must be of "
221  "Float type, got ")
222  << t.getElementType();
223  return Type();
224  }
225  } else {
226  parser.emitError(typeLoc, "matrix must be composed using vector "
227  "type, got ")
228  << type;
229  return Type();
230  }
231 
232  return type;
233 }
234 
235 static Type parseAndVerifySampledImageType(SPIRVDialect const &dialect,
236  DialectAsmParser &parser) {
237  Type type;
238  SMLoc typeLoc = parser.getCurrentLocation();
239  if (parser.parseType(type))
240  return Type();
241 
242  if (!type.isa<ImageType>()) {
243  parser.emitError(typeLoc,
244  "sampled image must be composed using image type, got ")
245  << type;
246  return Type();
247  }
248 
249  return type;
250 }
251 
252 /// Parses an optional `, stride = N` assembly segment. If no parsing failure
253 /// occurs, writes `N` to `stride` if existing and writes 0 to `stride` if
254 /// missing.
255 static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect,
256  DialectAsmParser &parser,
257  unsigned &stride) {
258  if (failed(parser.parseOptionalComma())) {
259  stride = 0;
260  return success();
261  }
262 
263  if (parser.parseKeyword("stride") || parser.parseEqual())
264  return failure();
265 
266  SMLoc strideLoc = parser.getCurrentLocation();
267  Optional<unsigned> optStride = parseAndVerify<unsigned>(dialect, parser);
268  if (!optStride)
269  return failure();
270 
271  if (!(stride = *optStride)) {
272  parser.emitError(strideLoc, "ArrayStride must be greater than zero");
273  return failure();
274  }
275  return success();
276 }
277 
278 // element-type ::= integer-type
279 // | floating-point-type
280 // | vector-type
281 // | spirv-type
282 //
283 // array-type ::= `!spirv.array` `<` integer-literal `x` element-type
284 // (`,` `stride` `=` integer-literal)? `>`
285 static Type parseArrayType(SPIRVDialect const &dialect,
286  DialectAsmParser &parser) {
287  if (parser.parseLess())
288  return Type();
289 
290  SmallVector<int64_t, 1> countDims;
291  SMLoc countLoc = parser.getCurrentLocation();
292  if (parser.parseDimensionList(countDims, /*allowDynamic=*/false))
293  return Type();
294  if (countDims.size() != 1) {
295  parser.emitError(countLoc,
296  "expected single integer for array element count");
297  return Type();
298  }
299 
300  // According to the SPIR-V spec:
301  // "Length is the number of elements in the array. It must be at least 1."
302  int64_t count = countDims[0];
303  if (count == 0) {
304  parser.emitError(countLoc, "expected array length greater than 0");
305  return Type();
306  }
307 
308  Type elementType = parseAndVerifyType(dialect, parser);
309  if (!elementType)
310  return Type();
311 
312  unsigned stride = 0;
313  if (failed(parseOptionalArrayStride(dialect, parser, stride)))
314  return Type();
315 
316  if (parser.parseGreater())
317  return Type();
318  return ArrayType::get(elementType, count, stride);
319 }
320 
321 // cooperative-matrix-type ::= `!spirv.coopmatrix` `<` element-type ',' scope
322 // ','
323 // rows ',' columns>`
324 static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
325  DialectAsmParser &parser) {
326  if (parser.parseLess())
327  return Type();
328 
330  SMLoc countLoc = parser.getCurrentLocation();
331  if (parser.parseDimensionList(dims, /*allowDynamic=*/false))
332  return Type();
333 
334  if (dims.size() != 2) {
335  parser.emitError(countLoc, "expected rows and columns size");
336  return Type();
337  }
338 
339  auto elementTy = parseAndVerifyType(dialect, parser);
340  if (!elementTy)
341  return Type();
342 
343  Scope scope;
344  if (parser.parseComma() || parseEnumKeywordAttr(scope, parser, "scope <id>"))
345  return Type();
346 
347  if (parser.parseGreater())
348  return Type();
349  return CooperativeMatrixNVType::get(elementTy, scope, dims[0], dims[1]);
350 }
351 
352 // joint-matrix-type ::= `!spirv.jointmatrix` `<`rows `x` columns `x`
353 // element-type
354 // `,` layout `,` scope`>`
355 static Type parseJointMatrixType(SPIRVDialect const &dialect,
356  DialectAsmParser &parser) {
357  if (parser.parseLess())
358  return Type();
359 
361  SMLoc countLoc = parser.getCurrentLocation();
362  if (parser.parseDimensionList(dims, /*allowDynamic=*/false))
363  return Type();
364 
365  if (dims.size() != 2) {
366  parser.emitError(countLoc, "expected rows and columns size");
367  return Type();
368  }
369 
370  auto elementTy = parseAndVerifyType(dialect, parser);
371  if (!elementTy)
372  return Type();
373  MatrixLayout matrixLayout;
374  if (parser.parseComma() ||
375  parseEnumKeywordAttr(matrixLayout, parser, "matrixLayout <id>"))
376  return Type();
377  Scope scope;
378  if (parser.parseComma() || parseEnumKeywordAttr(scope, parser, "scope <id>"))
379  return Type();
380  if (parser.parseGreater())
381  return Type();
382  return JointMatrixINTELType::get(elementTy, scope, dims[0], dims[1],
383  matrixLayout);
384 }
385 
386 // TODO: Reorder methods to be utilities first and parse*Type
387 // methods in alphabetical order
388 //
389 // storage-class ::= `UniformConstant`
390 // | `Uniform`
391 // | `Workgroup`
392 // | <and other storage classes...>
393 //
394 // pointer-type ::= `!spirv.ptr<` element-type `,` storage-class `>`
395 static Type parsePointerType(SPIRVDialect const &dialect,
396  DialectAsmParser &parser) {
397  if (parser.parseLess())
398  return Type();
399 
400  auto pointeeType = parseAndVerifyType(dialect, parser);
401  if (!pointeeType)
402  return Type();
403 
404  StringRef storageClassSpec;
405  SMLoc storageClassLoc = parser.getCurrentLocation();
406  if (parser.parseComma() || parser.parseKeyword(&storageClassSpec))
407  return Type();
408 
409  auto storageClass = symbolizeStorageClass(storageClassSpec);
410  if (!storageClass) {
411  parser.emitError(storageClassLoc, "unknown storage class: ")
412  << storageClassSpec;
413  return Type();
414  }
415  if (parser.parseGreater())
416  return Type();
417  return PointerType::get(pointeeType, *storageClass);
418 }
419 
420 // runtime-array-type ::= `!spirv.rtarray` `<` element-type
421 // (`,` `stride` `=` integer-literal)? `>`
422 static Type parseRuntimeArrayType(SPIRVDialect const &dialect,
423  DialectAsmParser &parser) {
424  if (parser.parseLess())
425  return Type();
426 
427  Type elementType = parseAndVerifyType(dialect, parser);
428  if (!elementType)
429  return Type();
430 
431  unsigned stride = 0;
432  if (failed(parseOptionalArrayStride(dialect, parser, stride)))
433  return Type();
434 
435  if (parser.parseGreater())
436  return Type();
437  return RuntimeArrayType::get(elementType, stride);
438 }
439 
440 // matrix-type ::= `!spirv.matrix` `<` integer-literal `x` element-type `>`
441 static Type parseMatrixType(SPIRVDialect const &dialect,
442  DialectAsmParser &parser) {
443  if (parser.parseLess())
444  return Type();
445 
446  SmallVector<int64_t, 1> countDims;
447  SMLoc countLoc = parser.getCurrentLocation();
448  if (parser.parseDimensionList(countDims, /*allowDynamic=*/false))
449  return Type();
450  if (countDims.size() != 1) {
451  parser.emitError(countLoc, "expected single unsigned "
452  "integer for number of columns");
453  return Type();
454  }
455 
456  int64_t columnCount = countDims[0];
457  // According to the specification, Matrices can have 2, 3, or 4 columns
458  if (columnCount < 2 || columnCount > 4) {
459  parser.emitError(countLoc, "matrix is expected to have 2, 3, or 4 "
460  "columns");
461  return Type();
462  }
463 
464  Type columnType = parseAndVerifyMatrixType(dialect, parser);
465  if (!columnType)
466  return Type();
467 
468  if (parser.parseGreater())
469  return Type();
470 
471  return MatrixType::get(columnType, columnCount);
472 }
473 
474 // Specialize this function to parse each of the parameters that define an
475 // ImageType. By default it assumes this is an enum type.
476 template <typename ValTy>
477 static Optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
478  DialectAsmParser &parser) {
479  StringRef enumSpec;
480  SMLoc enumLoc = parser.getCurrentLocation();
481  if (parser.parseKeyword(&enumSpec)) {
482  return llvm::None;
483  }
484 
485  auto val = spirv::symbolizeEnum<ValTy>(enumSpec);
486  if (!val)
487  parser.emitError(enumLoc, "unknown attribute: '") << enumSpec << "'";
488  return val;
489 }
490 
491 template <>
492 Optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
493  DialectAsmParser &parser) {
494  // TODO: Further verify that the element type can be sampled
495  auto ty = parseAndVerifyType(dialect, parser);
496  if (!ty)
497  return llvm::None;
498  return ty;
499 }
500 
501 template <typename IntTy>
502 static Optional<IntTy> parseAndVerifyInteger(SPIRVDialect const &dialect,
503  DialectAsmParser &parser) {
504  IntTy offsetVal = std::numeric_limits<IntTy>::max();
505  if (parser.parseInteger(offsetVal))
506  return llvm::None;
507  return offsetVal;
508 }
509 
510 template <>
511 Optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
512  DialectAsmParser &parser) {
513  return parseAndVerifyInteger<unsigned>(dialect, parser);
514 }
515 
516 namespace {
517 // Functor object to parse a comma separated list of specs. The function
518 // parseAndVerify does the actual parsing and verification of individual
519 // elements. This is a functor since parsing the last element of the list
520 // (termination condition) needs partial specialization.
521 template <typename ParseType, typename... Args>
522 struct ParseCommaSeparatedList {
523  Optional<std::tuple<ParseType, Args...>>
524  operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const {
525  auto parseVal = parseAndVerify<ParseType>(dialect, parser);
526  if (!parseVal)
527  return llvm::None;
528 
529  auto numArgs = std::tuple_size<std::tuple<Args...>>::value;
530  if (numArgs != 0 && failed(parser.parseComma()))
531  return llvm::None;
532  auto remainingValues = ParseCommaSeparatedList<Args...>{}(dialect, parser);
533  if (!remainingValues)
534  return llvm::None;
535  return std::tuple_cat(std::tuple<ParseType>(parseVal.value()),
536  remainingValues.value());
537  }
538 };
539 
540 // Partial specialization of the function to parse a comma separated list of
541 // specs to parse the last element of the list.
542 template <typename ParseType>
543 struct ParseCommaSeparatedList<ParseType> {
544  Optional<std::tuple<ParseType>> operator()(SPIRVDialect const &dialect,
545  DialectAsmParser &parser) const {
546  if (auto value = parseAndVerify<ParseType>(dialect, parser))
547  return std::tuple<ParseType>(*value);
548  return llvm::None;
549  }
550 };
551 } // namespace
552 
553 // dim ::= `1D` | `2D` | `3D` | `Cube` | <and other SPIR-V Dim specifiers...>
554 //
555 // depth-info ::= `NoDepth` | `IsDepth` | `DepthUnknown`
556 //
557 // arrayed-info ::= `NonArrayed` | `Arrayed`
558 //
559 // sampling-info ::= `SingleSampled` | `MultiSampled`
560 //
561 // sampler-use-info ::= `SamplerUnknown` | `NeedSampler` | `NoSampler`
562 //
563 // format ::= `Unknown` | `Rgba32f` | <and other SPIR-V Image formats...>
564 //
565 // image-type ::= `!spirv.image<` element-type `,` dim `,` depth-info `,`
566 // arrayed-info `,` sampling-info `,`
567 // sampler-use-info `,` format `>`
568 static Type parseImageType(SPIRVDialect const &dialect,
569  DialectAsmParser &parser) {
570  if (parser.parseLess())
571  return Type();
572 
573  auto value =
574  ParseCommaSeparatedList<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
575  ImageSamplingInfo, ImageSamplerUseInfo,
576  ImageFormat>{}(dialect, parser);
577  if (!value)
578  return Type();
579 
580  if (parser.parseGreater())
581  return Type();
582  return ImageType::get(*value);
583 }
584 
585 // sampledImage-type :: = `!spirv.sampledImage<` image-type `>`
586 static Type parseSampledImageType(SPIRVDialect const &dialect,
587  DialectAsmParser &parser) {
588  if (parser.parseLess())
589  return Type();
590 
591  Type parsedType = parseAndVerifySampledImageType(dialect, parser);
592  if (!parsedType)
593  return Type();
594 
595  if (parser.parseGreater())
596  return Type();
597  return SampledImageType::get(parsedType);
598 }
599 
600 // Parse decorations associated with a member.
602  SPIRVDialect const &dialect, DialectAsmParser &parser,
603  ArrayRef<Type> memberTypes,
606 
607  // Check if the first element is offset.
608  SMLoc offsetLoc = parser.getCurrentLocation();
609  StructType::OffsetInfo offset = 0;
610  OptionalParseResult offsetParseResult = parser.parseOptionalInteger(offset);
611  if (offsetParseResult.has_value()) {
612  if (failed(*offsetParseResult))
613  return failure();
614 
615  if (offsetInfo.size() != memberTypes.size() - 1) {
616  return parser.emitError(offsetLoc,
617  "offset specification must be given for "
618  "all members");
619  }
620  offsetInfo.push_back(offset);
621  }
622 
623  // Check for no spirv::Decorations.
624  if (succeeded(parser.parseOptionalRSquare()))
625  return success();
626 
627  // If there was an offset, make sure to parse the comma.
628  if (offsetParseResult.has_value() && parser.parseComma())
629  return failure();
630 
631  // Check for spirv::Decorations.
632  auto parseDecorations = [&]() {
633  auto memberDecoration = parseAndVerify<spirv::Decoration>(dialect, parser);
634  if (!memberDecoration)
635  return failure();
636 
637  // Parse member decoration value if it exists.
638  if (succeeded(parser.parseOptionalEqual())) {
639  auto memberDecorationValue =
640  parseAndVerifyInteger<uint32_t>(dialect, parser);
641 
642  if (!memberDecorationValue)
643  return failure();
644 
645  memberDecorationInfo.emplace_back(
646  static_cast<uint32_t>(memberTypes.size() - 1), 1,
647  memberDecoration.value(), memberDecorationValue.value());
648  } else {
649  memberDecorationInfo.emplace_back(
650  static_cast<uint32_t>(memberTypes.size() - 1), 0,
651  memberDecoration.value(), 0);
652  }
653  return success();
654  };
655  if (failed(parser.parseCommaSeparatedList(parseDecorations)) ||
656  failed(parser.parseRSquare()))
657  return failure();
658 
659  return success();
660 }
661 
662 // struct-member-decoration ::= integer-literal? spirv-decoration*
663 // struct-type ::=
664 // `!spirv.struct<` (id `,`)?
665 // `(`
666 // (spirv-type (`[` struct-member-decoration `]`)?)*
667 // `)>`
668 static Type parseStructType(SPIRVDialect const &dialect,
669  DialectAsmParser &parser) {
670  // TODO: This function is quite lengthy. Break it down into smaller chunks.
671 
672  // To properly resolve recursive references while parsing recursive struct
673  // types, we need to maintain a list of enclosing struct type names. This set
674  // maintains the names of struct types in which the type we are about to parse
675  // is nested.
676  //
677  // Note: This has to be thread_local to enable multiple threads to safely
678  // parse concurrently.
679  thread_local SetVector<StringRef> structContext;
680 
681  static auto removeIdentifierAndFail = [](SetVector<StringRef> &structContext,
682  StringRef identifier) {
683  if (!identifier.empty())
684  structContext.remove(identifier);
685 
686  return Type();
687  };
688 
689  if (parser.parseLess())
690  return Type();
691 
692  StringRef identifier;
693 
694  // Check if this is an identified struct type.
695  if (succeeded(parser.parseOptionalKeyword(&identifier))) {
696  // Check if this is a possible recursive reference.
697  if (succeeded(parser.parseOptionalGreater())) {
698  if (structContext.count(identifier) == 0) {
699  parser.emitError(
700  parser.getNameLoc(),
701  "recursive struct reference not nested in struct definition");
702 
703  return Type();
704  }
705 
706  return StructType::getIdentified(dialect.getContext(), identifier);
707  }
708 
709  if (failed(parser.parseComma()))
710  return Type();
711 
712  if (structContext.count(identifier) != 0) {
713  parser.emitError(parser.getNameLoc(),
714  "identifier already used for an enclosing struct");
715 
716  return removeIdentifierAndFail(structContext, identifier);
717  }
718 
719  structContext.insert(identifier);
720  }
721 
722  if (failed(parser.parseLParen()))
723  return removeIdentifierAndFail(structContext, identifier);
724 
725  if (succeeded(parser.parseOptionalRParen()) &&
726  succeeded(parser.parseOptionalGreater())) {
727  if (!identifier.empty())
728  structContext.remove(identifier);
729 
730  return StructType::getEmpty(dialect.getContext(), identifier);
731  }
732 
733  StructType idStructTy;
734 
735  if (!identifier.empty())
736  idStructTy = StructType::getIdentified(dialect.getContext(), identifier);
737 
738  SmallVector<Type, 4> memberTypes;
741 
742  do {
743  Type memberType;
744  if (parser.parseType(memberType))
745  return removeIdentifierAndFail(structContext, identifier);
746  memberTypes.push_back(memberType);
747 
748  if (succeeded(parser.parseOptionalLSquare()))
749  if (parseStructMemberDecorations(dialect, parser, memberTypes, offsetInfo,
750  memberDecorationInfo))
751  return removeIdentifierAndFail(structContext, identifier);
752  } while (succeeded(parser.parseOptionalComma()));
753 
754  if (!offsetInfo.empty() && memberTypes.size() != offsetInfo.size()) {
755  parser.emitError(parser.getNameLoc(),
756  "offset specification must be given for all members");
757  return removeIdentifierAndFail(structContext, identifier);
758  }
759 
760  if (failed(parser.parseRParen()) || failed(parser.parseGreater()))
761  return removeIdentifierAndFail(structContext, identifier);
762 
763  if (!identifier.empty()) {
764  if (failed(idStructTy.trySetBody(memberTypes, offsetInfo,
765  memberDecorationInfo)))
766  return Type();
767 
768  structContext.remove(identifier);
769  return idStructTy;
770  }
771 
772  return StructType::get(memberTypes, offsetInfo, memberDecorationInfo);
773 }
774 
775 // spirv-type ::= array-type
776 // | element-type
777 // | image-type
778 // | pointer-type
779 // | runtime-array-type
780 // | sampled-image-type
781 // | struct-type
783  StringRef keyword;
784  if (parser.parseKeyword(&keyword))
785  return Type();
786 
787  if (keyword == "array")
788  return parseArrayType(*this, parser);
789  if (keyword == "coopmatrix")
790  return parseCooperativeMatrixType(*this, parser);
791  if (keyword == "jointmatrix")
792  return parseJointMatrixType(*this, parser);
793  if (keyword == "image")
794  return parseImageType(*this, parser);
795  if (keyword == "ptr")
796  return parsePointerType(*this, parser);
797  if (keyword == "rtarray")
798  return parseRuntimeArrayType(*this, parser);
799  if (keyword == "sampled_image")
800  return parseSampledImageType(*this, parser);
801  if (keyword == "struct")
802  return parseStructType(*this, parser);
803  if (keyword == "matrix")
804  return parseMatrixType(*this, parser);
805  parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword;
806  return Type();
807 }
808 
809 //===----------------------------------------------------------------------===//
810 // Type Printing
811 //===----------------------------------------------------------------------===//
812 
813 static void print(ArrayType type, DialectAsmPrinter &os) {
814  os << "array<" << type.getNumElements() << " x " << type.getElementType();
815  if (unsigned stride = type.getArrayStride())
816  os << ", stride=" << stride;
817  os << ">";
818 }
819 
820 static void print(RuntimeArrayType type, DialectAsmPrinter &os) {
821  os << "rtarray<" << type.getElementType();
822  if (unsigned stride = type.getArrayStride())
823  os << ", stride=" << stride;
824  os << ">";
825 }
826 
827 static void print(PointerType type, DialectAsmPrinter &os) {
828  os << "ptr<" << type.getPointeeType() << ", "
829  << stringifyStorageClass(type.getStorageClass()) << ">";
830 }
831 
832 static void print(ImageType type, DialectAsmPrinter &os) {
833  os << "image<" << type.getElementType() << ", " << stringifyDim(type.getDim())
834  << ", " << stringifyImageDepthInfo(type.getDepthInfo()) << ", "
835  << stringifyImageArrayedInfo(type.getArrayedInfo()) << ", "
836  << stringifyImageSamplingInfo(type.getSamplingInfo()) << ", "
837  << stringifyImageSamplerUseInfo(type.getSamplerUseInfo()) << ", "
838  << stringifyImageFormat(type.getImageFormat()) << ">";
839 }
840 
841 static void print(SampledImageType type, DialectAsmPrinter &os) {
842  os << "sampled_image<" << type.getImageType() << ">";
843 }
844 
845 static void print(StructType type, DialectAsmPrinter &os) {
846  thread_local SetVector<StringRef> structContext;
847 
848  os << "struct<";
849 
850  if (type.isIdentified()) {
851  os << type.getIdentifier();
852 
853  if (structContext.count(type.getIdentifier())) {
854  os << ">";
855  return;
856  }
857 
858  os << ", ";
859  structContext.insert(type.getIdentifier());
860  }
861 
862  os << "(";
863 
864  auto printMember = [&](unsigned i) {
865  os << type.getElementType(i);
867  type.getMemberDecorations(i, decorations);
868  if (type.hasOffset() || !decorations.empty()) {
869  os << " [";
870  if (type.hasOffset()) {
871  os << type.getMemberOffset(i);
872  if (!decorations.empty())
873  os << ", ";
874  }
875  auto eachFn = [&os](spirv::StructType::MemberDecorationInfo decoration) {
876  os << stringifyDecoration(decoration.decoration);
877  if (decoration.hasValue) {
878  os << "=" << decoration.decorationValue;
879  }
880  };
881  llvm::interleaveComma(decorations, os, eachFn);
882  os << "]";
883  }
884  };
885  llvm::interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os,
886  printMember);
887  os << ")>";
888 
889  if (type.isIdentified())
890  structContext.remove(type.getIdentifier());
891 }
892 
894  os << "coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x";
895  os << type.getElementType() << ", " << stringifyScope(type.getScope());
896  os << ">";
897 }
898 
900  os << "jointmatrix<" << type.getRows() << "x" << type.getColumns() << "x";
901  os << type.getElementType() << ", "
902  << stringifyMatrixLayout(type.getMatrixLayout());
903  os << ", " << stringifyScope(type.getScope()) << ">";
904 }
905 
906 static void print(MatrixType type, DialectAsmPrinter &os) {
907  os << "matrix<" << type.getNumColumns() << " x " << type.getColumnType();
908  os << ">";
909 }
910 
911 void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
912  TypeSwitch<Type>(type)
915  StructType, MatrixType>([&](auto type) { print(type, os); })
916  .Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
917 }
918 
919 //===----------------------------------------------------------------------===//
920 // Constant
921 //===----------------------------------------------------------------------===//
922 
924  Attribute value, Type type,
925  Location loc) {
926  if (!spirv::ConstantOp::isBuildableWith(type))
927  return nullptr;
928 
929  return builder.create<spirv::ConstantOp>(loc, type, value);
930 }
931 
932 //===----------------------------------------------------------------------===//
933 // Shader Interface ABI
934 //===----------------------------------------------------------------------===//
935 
936 LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op,
937  NamedAttribute attribute) {
938  StringRef symbol = attribute.getName().strref();
939  Attribute attr = attribute.getValue();
940 
941  if (symbol == spirv::getEntryPointABIAttrName()) {
942  if (!attr.isa<spirv::EntryPointABIAttr>()) {
943  return op->emitError("'")
944  << symbol << "' attribute must be an entry point ABI attribute";
945  }
946  } else if (symbol == spirv::getTargetEnvAttrName()) {
947  if (!attr.isa<spirv::TargetEnvAttr>())
948  return op->emitError("'") << symbol << "' must be a spirv::TargetEnvAttr";
949  } else {
950  return op->emitError("found unsupported '")
951  << symbol << "' attribute on operation";
952  }
953 
954  return success();
955 }
956 
957 /// Verifies the given SPIR-V `attribute` attached to a value of the given
958 /// `valueType` is valid.
960  NamedAttribute attribute) {
961  StringRef symbol = attribute.getName().strref();
962  Attribute attr = attribute.getValue();
963 
964  if (symbol != spirv::getInterfaceVarABIAttrName())
965  return emitError(loc, "found unsupported '")
966  << symbol << "' attribute on region argument";
967 
968  auto varABIAttr = attr.dyn_cast<spirv::InterfaceVarABIAttr>();
969  if (!varABIAttr)
970  return emitError(loc, "'")
971  << symbol << "' must be a spirv::InterfaceVarABIAttr";
972 
973  if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat())
974  return emitError(loc, "'") << symbol
975  << "' attribute cannot specify storage class "
976  "when attaching to a non-scalar value";
977 
978  return success();
979 }
980 
981 LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
982  unsigned regionIndex,
983  unsigned argIndex,
984  NamedAttribute attribute) {
985  return verifyRegionAttribute(
986  op->getLoc(), op->getRegion(regionIndex).getArgument(argIndex).getType(),
987  attribute);
988 }
989 
990 LogicalResult SPIRVDialect::verifyRegionResultAttribute(
991  Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/,
992  NamedAttribute attribute) {
993  return op->emitError("cannot attach SPIR-V attributes to region result");
994 }
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, BlockAndValueMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static constexpr const bool value
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect, DialectAsmParser &parser, unsigned &stride)
Parses an optional `, stride = N` assembly segment.
static LogicalResult verifyRegionAttribute(Location loc, Type valueType, NamedAttribute attribute)
Verifies the given SPIR-V attribute attached to a value of the given valueType is valid.
static Type parseJointMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static void print(ArrayType type, DialectAsmPrinter &os)
static Type parseSampledImageType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseAndVerifyType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static ParseResult parseStructMemberDecorations(SPIRVDialect const &dialect, DialectAsmParser &parser, ArrayRef< Type > memberTypes, SmallVectorImpl< StructType::OffsetInfo > &offsetInfo, SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorationInfo)
static Type parseAndVerifySampledImageType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseCooperativeMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
Optional< unsigned > parseAndVerify< unsigned >(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseArrayType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Optional< ValTy > parseAndVerify(SPIRVDialect const &dialect, DialectAsmParser &parser)
static bool containsReturn(Region &region)
Returns true if the given region contains spirv.Return or spirv.ReturnValue ops.
static Type parseStructType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseRuntimeArrayType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseImageType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Optional< IntTy > parseAndVerifyInteger(SPIRVDialect const &dialect, DialectAsmParser &parser)
Optional< Type > parseAndVerify< Type >(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parsePointerType(SPIRVDialect const &dialect, DialectAsmParser &parser)
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalEqual()=0
Parse a `=` token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRParen()=0
Parse a `)` token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a `]` token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalRParen()=0
Parse a `)` token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseOptionalGreater()=0
Parse a '>' token if present.
virtual ParseResult parseEqual()=0
Parse a `=` token.
virtual SMLoc getCurrentLocation()=0
Return the location of the next token.
virtual ParseResult parseOptionalComma()=0
Parse a `,` token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalRSquare()=0
Parse a `]` token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a `(` token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a `,` token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a `[` token if present.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
U dyn_cast() const
Definition: Attributes.h:127
bool isa() const
Casting utility functions.
Definition: Attributes.h:117
Block represents an ordered list of Operations.
Definition: Block.h:30
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:232
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:41
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:43
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:64
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:150
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:32
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:164
This class helps build Operations.
Definition: Builders.h:198
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:422
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:31
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:154
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:225
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:486
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:418
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:37
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:47
This class represents success/failure for parsing-like operations that find it important to chain tog...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
BlockArgument getArgument(unsigned i)
Definition: Region.h:124
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:121
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:91
U dyn_cast() const
Definition: Types.h:270
bool isa() const
Definition: Types.h:260
bool isBF16() const
Definition: Types.cpp:23
Type getType() const
Return the type of this value.
Definition: Value.h:114
Type getElementType() const
Definition: SPIRVTypes.cpp:64
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:66
unsigned getNumElements() const
Definition: SPIRVTypes.cpp:62
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:50
unsigned getRows() const
Returns the number of rows of the matrix.
Definition: SPIRVTypes.cpp:242
unsigned getColumns() const
Returns the number of columns of the matrix.
Definition: SPIRVTypes.cpp:244
static CooperativeMatrixNVType get(Type elementType, Scope scope, unsigned rows, unsigned columns)
Definition: SPIRVTypes.cpp:230
Scope getScope() const
Return the scope of the cooperative matrix.
Definition: SPIRVTypes.cpp:240
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition: SPIRVTypes.h:164
ImageDepthInfo getDepthInfo() const
Definition: SPIRVTypes.cpp:420
ImageArrayedInfo getArrayedInfo() const
Definition: SPIRVTypes.cpp:422
ImageFormat getImageFormat() const
Definition: SPIRVTypes.cpp:434
ImageSamplerUseInfo getSamplerUseInfo() const
Definition: SPIRVTypes.cpp:430
Type getElementType() const
Definition: SPIRVTypes.cpp:416
ImageSamplingInfo getSamplingInfo() const
Definition: SPIRVTypes.cpp:426
An attribute that specifies the information regarding the interface variable: descriptor set,...
Scope getScope() const
Return the scope of the joint matrix.
Definition: SPIRVTypes.cpp:306
unsigned getColumns() const
Returns the number of columns of the matrix.
Definition: SPIRVTypes.cpp:310
static JointMatrixINTELType get(Type elementType, Scope scope, unsigned rows, unsigned columns, MatrixLayout matrixLayout)
Definition: SPIRVTypes.cpp:295
unsigned getRows() const
Returns the number of rows of the matrix.
Definition: SPIRVTypes.cpp:308
MatrixLayout getMatrixLayout() const
Returns the layout of the matrix.
Definition: SPIRVTypes.cpp:312
static MatrixType get(Type columnType, uint32_t columnCount)
Type getColumnType() const
unsigned getNumColumns() const
Returns the number of columns.
Type getPointeeType() const
Definition: SPIRVTypes.cpp:480
StorageClass getStorageClass() const
Definition: SPIRVTypes.cpp:482
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:476
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:543
static RuntimeArrayType get(Type elementType)
Definition: SPIRVTypes.cpp:533
static SampledImageType get(Type imageType)
Definition: SPIRVTypes.cpp:808
static bool isValid(FloatType)
Returns true if the given float type is valid for the SPIR-V dialect.
Definition: SPIRVTypes.cpp:576
SPIR-V struct type.
Definition: SPIRVTypes.h:281
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
unsigned getNumElements() const
Type getElementType(unsigned) const
LogicalResult trySetBody(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Sets the contents of an incomplete identified StructType.
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
uint64_t getMemberOffset(unsigned) const
An attribute that specifies the target version, allowed extensions and capabilities,...
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
StringRef getTargetEnvAttrName()
Returns the attribute name for specifying SPIR-V target environment.
StringRef getEntryPointABIAttrName()
Returns the attribute name for specifying entry point information.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
static ParseResult parseEnumKeywordAttr(EnumClass &value, ParserType &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next keyword in parser as an enumerant of the given EnumClass.
Definition: ParserUtils.h:27
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Type parseType(llvm::StringRef typeStr, MLIRContext *context)
Parses a single MLIR type from `typeStr` into the given MLIR context, if the string is valid.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26