MLIR  15.0.0git
SPIRVDialect.cpp
Go to the documentation of this file.
1 //===- SPIRVDialect.cpp - MLIR SPIR-V dialect -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the SPIR-V dialect in MLIR.
10 //
11 //===----------------------------------------------------------------------===//
12 
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/MLIRContext.h"
22 #include "mlir/Parser/Parser.h"
24 #include "llvm/ADT/DenseMap.h"
25 #include "llvm/ADT/Sequence.h"
26 #include "llvm/ADT/SetVector.h"
27 #include "llvm/ADT/StringExtras.h"
28 #include "llvm/ADT/StringMap.h"
29 #include "llvm/ADT/StringSwitch.h"
30 #include "llvm/ADT/TypeSwitch.h"
31 #include "llvm/Support/raw_ostream.h"
32 
33 using namespace mlir;
34 using namespace mlir::spirv;
35 
36 #include "mlir/Dialect/SPIRV/IR/SPIRVOpsDialect.cpp.inc"
37 
38 //===----------------------------------------------------------------------===//
39 // InlinerInterface
40 //===----------------------------------------------------------------------===//
41 
42 /// Returns true if the given region contains spv.Return or spv.ReturnValue ops.
43 static inline bool containsReturn(Region &region) {
44  return llvm::any_of(region, [](Block &block) {
45  Operation *terminator = block.getTerminator();
46  return isa<spirv::ReturnOp, spirv::ReturnValueOp>(terminator);
47  });
48 }
49 
50 namespace {
51 /// This class defines the interface for inlining within the SPIR-V dialect.
52 struct SPIRVInlinerInterface : public DialectInlinerInterface {
54 
55  /// All call operations within SPIRV can be inlined.
56  bool isLegalToInline(Operation *call, Operation *callable,
57  bool wouldBeCloned) const final {
58  return true;
59  }
60 
61  /// Returns true if the given region 'src' can be inlined into the region
62  /// 'dest' that is attached to an operation registered to the current dialect.
63  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
64  BlockAndValueMapping &) const final {
65  // Return true here when inlining into spv.func, spv.mlir.selection, and
66  // spv.mlir.loop operations.
67  auto *op = dest->getParentOp();
68  return isa<spirv::FuncOp, spirv::SelectionOp, spirv::LoopOp>(op);
69  }
70 
71  /// Returns true if the given operation 'op', that is registered to this
72  /// dialect, can be inlined into the region 'dest' that is attached to an
73  /// operation registered to the current dialect.
74  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
75  BlockAndValueMapping &) const final {
76  // TODO: Enable inlining structured control flows with return.
77  if ((isa<spirv::SelectionOp, spirv::LoopOp>(op)) &&
78  containsReturn(op->getRegion(0)))
79  return false;
80  // TODO: we need to filter OpKill here to avoid inlining it to
81  // a loop continue construct:
82  // https://github.com/KhronosGroup/SPIRV-Headers/issues/86
83  // However OpKill is fragment shader specific and we don't support it yet.
84  return true;
85  }
86 
87  /// Handle the given inlined terminator by replacing it with a new operation
88  /// as necessary.
89  void handleTerminator(Operation *op, Block *newDest) const final {
90  if (auto returnOp = dyn_cast<spirv::ReturnOp>(op)) {
91  OpBuilder(op).create<spirv::BranchOp>(op->getLoc(), newDest);
92  op->erase();
93  } else if (auto retValOp = dyn_cast<spirv::ReturnValueOp>(op)) {
94  llvm_unreachable("unimplemented spv.ReturnValue in inliner");
95  }
96  }
97 
98  /// Handle the given inlined terminator by replacing it with a new operation
99  /// as necessary.
100  void handleTerminator(Operation *op,
101  ArrayRef<Value> valuesToRepl) const final {
102  // Only spv.ReturnValue needs to be handled here.
103  auto retValOp = dyn_cast<spirv::ReturnValueOp>(op);
104  if (!retValOp)
105  return;
106 
107  // Replace the values directly with the return operands.
108  assert(valuesToRepl.size() == 1 &&
109  "spv.ReturnValue expected to only handle one result");
110  valuesToRepl.front().replaceAllUsesWith(retValOp.value());
111  }
112 };
113 } // namespace
114 
115 //===----------------------------------------------------------------------===//
116 // SPIR-V Dialect
117 //===----------------------------------------------------------------------===//
118 
/// Sets up the SPIR-V dialect: registers its attributes and types, registers
/// every op listed in the TableGen-generated SPIRVOps.cpp.inc, attaches the
/// SPIR-V inliner interface, and permits operations unknown to the dialect.
119 void SPIRVDialect::initialize() {
120  registerAttributes();
121  registerTypes();
122 
123  // Add SPIR-V ops.
124  addOperations<
125 #define GET_OP_LIST
126 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"
127  >();
128 
// Make SPIR-V calls and regions inlinable via SPIRVInlinerInterface.
129  addInterfaces<SPIRVInlinerInterface>();
130 
131  // Allow unknown operations because SPIR-V is extensible.
132  allowUnknownOperations();
133 }
134 
135 std::string SPIRVDialect::getAttributeName(Decoration decoration) {
136  return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration));
137 }
138 
139 //===----------------------------------------------------------------------===//
140 // Type Parsing
141 //===----------------------------------------------------------------------===//
142 
143 // Forward declarations.
// parseAndVerify<ValTy> parses one parameter of a composite SPIR-V type. The
// primary template (defined further below) handles enum keywords; the
// explicit specializations handle `Type` and `unsigned` parameters.
// parseOptionalArrayStride uses parseAndVerify<unsigned> before those
// definitions appear, hence these declarations.
144 template <typename ValTy>
145 static Optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
147 template <>
148 Optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
149  DialectAsmParser &parser);
150 
151 template <>
152 Optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
153  DialectAsmParser &parser);
154 
155 static Type parseAndVerifyType(SPIRVDialect const &dialect,
156  DialectAsmParser &parser) {
157  Type type;
158  SMLoc typeLoc = parser.getCurrentLocation();
159  if (parser.parseType(type))
160  return Type();
161 
162  // Allow SPIR-V dialect types
163  if (&type.getDialect() == &dialect)
164  return type;
165 
166  // Check other allowed types
167  if (auto t = type.dyn_cast<FloatType>()) {
168  if (type.isBF16()) {
169  parser.emitError(typeLoc, "cannot use 'bf16' to compose SPIR-V types");
170  return Type();
171  }
172  } else if (auto t = type.dyn_cast<IntegerType>()) {
173  if (!ScalarType::isValid(t)) {
174  parser.emitError(typeLoc,
175  "only 1/8/16/32/64-bit integer type allowed but found ")
176  << type;
177  return Type();
178  }
179  } else if (auto t = type.dyn_cast<VectorType>()) {
180  if (t.getRank() != 1) {
181  parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
182  return Type();
183  }
184  if (t.getNumElements() > 4) {
185  parser.emitError(
186  typeLoc, "vector length has to be less than or equal to 4 but found ")
187  << t.getNumElements();
188  return Type();
189  }
190  } else {
191  parser.emitError(typeLoc, "cannot use ")
192  << type << " to compose SPIR-V types";
193  return Type();
194  }
195 
196  return type;
197 }
198 
199 static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect,
200  DialectAsmParser &parser) {
201  Type type;
202  SMLoc typeLoc = parser.getCurrentLocation();
203  if (parser.parseType(type))
204  return Type();
205 
206  if (auto t = type.dyn_cast<VectorType>()) {
207  if (t.getRank() != 1) {
208  parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
209  return Type();
210  }
211  if (t.getNumElements() > 4 || t.getNumElements() < 2) {
212  parser.emitError(typeLoc,
213  "matrix columns size has to be less than or equal "
214  "to 4 and greater than or equal 2, but found ")
215  << t.getNumElements();
216  return Type();
217  }
218 
219  if (!t.getElementType().isa<FloatType>()) {
220  parser.emitError(typeLoc, "matrix columns' elements must be of "
221  "Float type, got ")
222  << t.getElementType();
223  return Type();
224  }
225  } else {
226  parser.emitError(typeLoc, "matrix must be composed using vector "
227  "type, got ")
228  << type;
229  return Type();
230  }
231 
232  return type;
233 }
234 
235 static Type parseAndVerifySampledImageType(SPIRVDialect const &dialect,
236  DialectAsmParser &parser) {
237  Type type;
238  SMLoc typeLoc = parser.getCurrentLocation();
239  if (parser.parseType(type))
240  return Type();
241 
242  if (!type.isa<ImageType>()) {
243  parser.emitError(typeLoc,
244  "sampled image must be composed using image type, got ")
245  << type;
246  return Type();
247  }
248 
249  return type;
250 }
251 
252 /// Parses an optional `, stride = N` assembly segment. If no parsing failure
253 /// occurs, writes `N` to `stride` if existing and writes 0 to `stride` if
254 /// missing.
255 static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect,
256  DialectAsmParser &parser,
257  unsigned &stride) {
258  if (failed(parser.parseOptionalComma())) {
259  stride = 0;
260  return success();
261  }
262 
263  if (parser.parseKeyword("stride") || parser.parseEqual())
264  return failure();
265 
266  SMLoc strideLoc = parser.getCurrentLocation();
267  Optional<unsigned> optStride = parseAndVerify<unsigned>(dialect, parser);
268  if (!optStride)
269  return failure();
270 
271  if (!(stride = *optStride)) {
272  parser.emitError(strideLoc, "ArrayStride must be greater than zero");
273  return failure();
274  }
275  return success();
276 }
277 
278 // element-type ::= integer-type
279 // | floating-point-type
280 // | vector-type
281 // | spirv-type
282 //
283 // array-type ::= `!spv.array` `<` integer-literal `x` element-type
284 // (`,` `stride` `=` integer-literal)? `>`
285 static Type parseArrayType(SPIRVDialect const &dialect,
286  DialectAsmParser &parser) {
287  if (parser.parseLess())
288  return Type();
289 
290  SmallVector<int64_t, 1> countDims;
291  SMLoc countLoc = parser.getCurrentLocation();
292  if (parser.parseDimensionList(countDims, /*allowDynamic=*/false))
293  return Type();
294  if (countDims.size() != 1) {
295  parser.emitError(countLoc,
296  "expected single integer for array element count");
297  return Type();
298  }
299 
300  // According to the SPIR-V spec:
301  // "Length is the number of elements in the array. It must be at least 1."
302  int64_t count = countDims[0];
303  if (count == 0) {
304  parser.emitError(countLoc, "expected array length greater than 0");
305  return Type();
306  }
307 
308  Type elementType = parseAndVerifyType(dialect, parser);
309  if (!elementType)
310  return Type();
311 
312  unsigned stride = 0;
313  if (failed(parseOptionalArrayStride(dialect, parser, stride)))
314  return Type();
315 
316  if (parser.parseGreater())
317  return Type();
318  return ArrayType::get(elementType, count, stride);
319 }
320 
321 // cooperative-matrix-type ::= `!spv.coopmatrix` `<` rows `x` columns `x`
322 //                             element-type `,` scope `>`
// Parses the body of a cooperative matrix type after the `coopmatrix`
// keyword. The leading dimension list must contain exactly two static sizes,
// used as the row and column counts. The element type is validated by
// parseAndVerifyType, and the scope is parsed as a SPIR-V Scope keyword.
323 static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
324  DialectAsmParser &parser) {
325  if (parser.parseLess())
326  return Type();
327 
// `dims` receives the `rows x columns x` prefix; dynamic sizes are rejected.
329  SMLoc countLoc = parser.getCurrentLocation();
330  if (parser.parseDimensionList(dims, /*allowDynamic=*/false))
331  return Type();
332 
333  if (dims.size() != 2) {
334  parser.emitError(countLoc, "expected rows and columns size");
335  return Type();
336  }
337 
338  auto elementTy = parseAndVerifyType(dialect, parser);
339  if (!elementTy)
340  return Type();
341 
342  Scope scope;
343  if (parser.parseComma() || parseEnumKeywordAttr(scope, parser, "scope <id>"))
344  return Type();
345 
346  if (parser.parseGreater())
347  return Type();
348  return CooperativeMatrixNVType::get(elementTy, scope, dims[0], dims[1]);
349 }
350 
351 // TODO: Reorder methods to be utilities first and parse*Type
352 // methods in alphabetical order
353 //
354 // storage-class ::= `UniformConstant`
355 // | `Uniform`
356 // | `Workgroup`
357 // | <and other storage classes...>
358 //
359 // pointer-type ::= `!spv.ptr<` element-type `,` storage-class `>`
360 static Type parsePointerType(SPIRVDialect const &dialect,
361  DialectAsmParser &parser) {
362  if (parser.parseLess())
363  return Type();
364 
365  auto pointeeType = parseAndVerifyType(dialect, parser);
366  if (!pointeeType)
367  return Type();
368 
369  StringRef storageClassSpec;
370  SMLoc storageClassLoc = parser.getCurrentLocation();
371  if (parser.parseComma() || parser.parseKeyword(&storageClassSpec))
372  return Type();
373 
374  auto storageClass = symbolizeStorageClass(storageClassSpec);
375  if (!storageClass) {
376  parser.emitError(storageClassLoc, "unknown storage class: ")
377  << storageClassSpec;
378  return Type();
379  }
380  if (parser.parseGreater())
381  return Type();
382  return PointerType::get(pointeeType, *storageClass);
383 }
384 
385 // runtime-array-type ::= `!spv.rtarray` `<` element-type
386 // (`,` `stride` `=` integer-literal)? `>`
387 static Type parseRuntimeArrayType(SPIRVDialect const &dialect,
388  DialectAsmParser &parser) {
389  if (parser.parseLess())
390  return Type();
391 
392  Type elementType = parseAndVerifyType(dialect, parser);
393  if (!elementType)
394  return Type();
395 
396  unsigned stride = 0;
397  if (failed(parseOptionalArrayStride(dialect, parser, stride)))
398  return Type();
399 
400  if (parser.parseGreater())
401  return Type();
402  return RuntimeArrayType::get(elementType, stride);
403 }
404 
405 // matrix-type ::= `!spv.matrix` `<` integer-literal `x` element-type `>`
406 static Type parseMatrixType(SPIRVDialect const &dialect,
407  DialectAsmParser &parser) {
408  if (parser.parseLess())
409  return Type();
410 
411  SmallVector<int64_t, 1> countDims;
412  SMLoc countLoc = parser.getCurrentLocation();
413  if (parser.parseDimensionList(countDims, /*allowDynamic=*/false))
414  return Type();
415  if (countDims.size() != 1) {
416  parser.emitError(countLoc, "expected single unsigned "
417  "integer for number of columns");
418  return Type();
419  }
420 
421  int64_t columnCount = countDims[0];
422  // According to the specification, Matrices can have 2, 3, or 4 columns
423  if (columnCount < 2 || columnCount > 4) {
424  parser.emitError(countLoc, "matrix is expected to have 2, 3, or 4 "
425  "columns");
426  return Type();
427  }
428 
429  Type columnType = parseAndVerifyMatrixType(dialect, parser);
430  if (!columnType)
431  return Type();
432 
433  if (parser.parseGreater())
434  return Type();
435 
436  return MatrixType::get(columnType, columnCount);
437 }
438 
439 // Specialize this function to parse each of the parameters that define an
440 // ImageType. By default it assumes this is an enum type.
441 template <typename ValTy>
442 static Optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
443  DialectAsmParser &parser) {
444  StringRef enumSpec;
445  SMLoc enumLoc = parser.getCurrentLocation();
446  if (parser.parseKeyword(&enumSpec)) {
447  return llvm::None;
448  }
449 
450  auto val = spirv::symbolizeEnum<ValTy>(enumSpec);
451  if (!val)
452  parser.emitError(enumLoc, "unknown attribute: '") << enumSpec << "'";
453  return val;
454 }
455 
456 template <>
457 Optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
458  DialectAsmParser &parser) {
459  // TODO: Further verify that the element type can be sampled
460  auto ty = parseAndVerifyType(dialect, parser);
461  if (!ty)
462  return llvm::None;
463  return ty;
464 }
465 
466 template <typename IntTy>
467 static Optional<IntTy> parseAndVerifyInteger(SPIRVDialect const &dialect,
468  DialectAsmParser &parser) {
469  IntTy offsetVal = std::numeric_limits<IntTy>::max();
470  if (parser.parseInteger(offsetVal))
471  return llvm::None;
472  return offsetVal;
473 }
474 
475 template <>
476 Optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
477  DialectAsmParser &parser) {
478  return parseAndVerifyInteger<unsigned>(dialect, parser);
479 }
480 
481 namespace {
482 // Functor object to parse a comma separated list of specs. The function
483 // parseAndVerify does the actual parsing and verification of individual
484 // elements. This is a functor since parsing the last element of the list
485 // (termination condition) needs partial specialization.
486 template <typename ParseType, typename... Args> struct ParseCommaSeparatedList {
487  Optional<std::tuple<ParseType, Args...>>
488  operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const {
489  auto parseVal = parseAndVerify<ParseType>(dialect, parser);
490  if (!parseVal)
491  return llvm::None;
492 
493  auto numArgs = std::tuple_size<std::tuple<Args...>>::value;
494  if (numArgs != 0 && failed(parser.parseComma()))
495  return llvm::None;
496  auto remainingValues = ParseCommaSeparatedList<Args...>{}(dialect, parser);
497  if (!remainingValues)
498  return llvm::None;
499  return std::tuple_cat(std::tuple<ParseType>(parseVal.getValue()),
500  remainingValues.getValue());
501  }
502 };
503 
504 // Partial specialization of the function to parse a comma separated list of
505 // specs to parse the last element of the list.
506 template <typename ParseType> struct ParseCommaSeparatedList<ParseType> {
507  Optional<std::tuple<ParseType>> operator()(SPIRVDialect const &dialect,
508  DialectAsmParser &parser) const {
509  if (auto value = parseAndVerify<ParseType>(dialect, parser))
510  return std::tuple<ParseType>(*value);
511  return llvm::None;
512  }
513 };
514 } // namespace
515 
516 // dim ::= `1D` | `2D` | `3D` | `Cube` | <and other SPIR-V Dim specifiers...>
517 //
518 // depth-info ::= `NoDepth` | `IsDepth` | `DepthUnknown`
519 //
520 // arrayed-info ::= `NonArrayed` | `Arrayed`
521 //
522 // sampling-info ::= `SingleSampled` | `MultiSampled`
523 //
524 // sampler-use-info ::= `SamplerUnknown` | `NeedSampler` | `NoSampler`
525 //
526 // format ::= `Unknown` | `Rgba32f` | <and other SPIR-V Image formats...>
527 //
528 // image-type ::= `!spv.image<` element-type `,` dim `,` depth-info `,`
529 // arrayed-info `,` sampling-info `,`
530 // sampler-use-info `,` format `>`
531 static Type parseImageType(SPIRVDialect const &dialect,
532  DialectAsmParser &parser) {
533  if (parser.parseLess())
534  return Type();
535 
536  auto value =
537  ParseCommaSeparatedList<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
538  ImageSamplingInfo, ImageSamplerUseInfo,
539  ImageFormat>{}(dialect, parser);
540  if (!value)
541  return Type();
542 
543  if (parser.parseGreater())
544  return Type();
545  return ImageType::get(*value);
546 }
547 
548 // sampledImage-type :: = `!spv.sampledImage<` image-type `>`
549 static Type parseSampledImageType(SPIRVDialect const &dialect,
550  DialectAsmParser &parser) {
551  if (parser.parseLess())
552  return Type();
553 
554  Type parsedType = parseAndVerifySampledImageType(dialect, parser);
555  if (!parsedType)
556  return Type();
557 
558  if (parser.parseGreater())
559  return Type();
560  return SampledImageType::get(parsedType);
561 }
562 
563 // Parse decorations associated with a member.
// Parses the contents of the `[...]` annotation of the most recently added
// struct member, after the caller has already consumed the opening `[`.
// Grammar: integer-literal? (`,`? decoration (`=` integer)?)* `]`, i.e. an
// optional byte offset followed by an optional comma-separated decoration
// list. Offsets are appended to `offsetInfo` and decorations to
// `memberDecorationInfo`; the member index recorded is memberTypes.size() - 1.
565  SPIRVDialect const &dialect, DialectAsmParser &parser,
566  ArrayRef<Type> memberTypes,
569 
570  // Check if the first element is offset.
571  SMLoc offsetLoc = parser.getCurrentLocation();
572  StructType::OffsetInfo offset = 0;
573  OptionalParseResult offsetParseResult = parser.parseOptionalInteger(offset);
574  if (offsetParseResult.hasValue()) {
575  if (failed(*offsetParseResult))
576  return failure();
577 
// An offset on this member is only valid if every earlier member also has
// one, i.e. offsetInfo already holds one entry per previous member.
578  if (offsetInfo.size() != memberTypes.size() - 1) {
579  return parser.emitError(offsetLoc,
580  "offset specification must be given for "
581  "all members");
582  }
583  offsetInfo.push_back(offset);
584  }
585 
586  // Check for no spirv::Decorations.
587  if (succeeded(parser.parseOptionalRSquare()))
588  return success();
589 
590  // If there was an offset, make sure to parse the comma.
591  if (offsetParseResult.hasValue() && parser.parseComma())
592  return failure();
593 
594  // Check for spirv::Decorations.
595  auto parseDecorations = [&]() {
596  auto memberDecoration = parseAndVerify<spirv::Decoration>(dialect, parser);
597  if (!memberDecoration)
598  return failure();
599 
600  // Parse member decoration value if it exists.
// NOTE(review): the emplace_back arguments appear to be (member index,
// has-value flag, decoration, value), matching the hasValue/decorationValue
// fields read by print(StructType) -- confirm against StructType.
601  if (succeeded(parser.parseOptionalEqual())) {
602  auto memberDecorationValue =
603  parseAndVerifyInteger<uint32_t>(dialect, parser);
604 
605  if (!memberDecorationValue)
606  return failure();
607 
608  memberDecorationInfo.emplace_back(
609  static_cast<uint32_t>(memberTypes.size() - 1), 1,
610  memberDecoration.getValue(), memberDecorationValue.getValue());
611  } else {
612  memberDecorationInfo.emplace_back(
613  static_cast<uint32_t>(memberTypes.size() - 1), 0,
614  memberDecoration.getValue(), 0);
615  }
616  return success();
617  };
618  if (failed(parser.parseCommaSeparatedList(parseDecorations)) ||
619  failed(parser.parseRSquare()))
620  return failure();
621 
622  return success();
623 }
624 
625 // struct-member-decoration ::= integer-literal? spirv-decoration*
626 // struct-type ::=
627 // `!spv.struct<` (id `,`)?
628 // `(`
629 // (spirv-type (`[` struct-member-decoration `]`)?)*
630 // `)>`
631 static Type parseStructType(SPIRVDialect const &dialect,
632  DialectAsmParser &parser) {
633  // TODO: This function is quite lengthy. Break it down into smaller chunks.
634 
635  // To properly resolve recursive references while parsing recursive struct
636  // types, we need to maintain a list of enclosing struct type names. This set
637  // maintains the names of struct types in which the type we are about to parse
638  // is nested.
639  //
640  // Note: This has to be thread_local to enable multiple threads to safely
641  // parse concurrently.
642  thread_local SetVector<StringRef> structContext;
643 
644  static auto removeIdentifierAndFail = [](SetVector<StringRef> &structContext,
645  StringRef identifier) {
646  if (!identifier.empty())
647  structContext.remove(identifier);
648 
649  return Type();
650  };
651 
652  if (parser.parseLess())
653  return Type();
654 
655  StringRef identifier;
656 
657  // Check if this is an identified struct type.
658  if (succeeded(parser.parseOptionalKeyword(&identifier))) {
659  // Check if this is a possible recursive reference.
660  if (succeeded(parser.parseOptionalGreater())) {
661  if (structContext.count(identifier) == 0) {
662  parser.emitError(
663  parser.getNameLoc(),
664  "recursive struct reference not nested in struct definition");
665 
666  return Type();
667  }
668 
669  return StructType::getIdentified(dialect.getContext(), identifier);
670  }
671 
672  if (failed(parser.parseComma()))
673  return Type();
674 
675  if (structContext.count(identifier) != 0) {
676  parser.emitError(parser.getNameLoc(),
677  "identifier already used for an enclosing struct");
678 
679  return removeIdentifierAndFail(structContext, identifier);
680  }
681 
682  structContext.insert(identifier);
683  }
684 
685  if (failed(parser.parseLParen()))
686  return removeIdentifierAndFail(structContext, identifier);
687 
688  if (succeeded(parser.parseOptionalRParen()) &&
689  succeeded(parser.parseOptionalGreater())) {
690  if (!identifier.empty())
691  structContext.remove(identifier);
692 
693  return StructType::getEmpty(dialect.getContext(), identifier);
694  }
695 
696  StructType idStructTy;
697 
698  if (!identifier.empty())
699  idStructTy = StructType::getIdentified(dialect.getContext(), identifier);
700 
701  SmallVector<Type, 4> memberTypes;
704 
705  do {
706  Type memberType;
707  if (parser.parseType(memberType))
708  return removeIdentifierAndFail(structContext, identifier);
709  memberTypes.push_back(memberType);
710 
711  if (succeeded(parser.parseOptionalLSquare()))
712  if (parseStructMemberDecorations(dialect, parser, memberTypes, offsetInfo,
713  memberDecorationInfo))
714  return removeIdentifierAndFail(structContext, identifier);
715  } while (succeeded(parser.parseOptionalComma()));
716 
717  if (!offsetInfo.empty() && memberTypes.size() != offsetInfo.size()) {
718  parser.emitError(parser.getNameLoc(),
719  "offset specification must be given for all members");
720  return removeIdentifierAndFail(structContext, identifier);
721  }
722 
723  if (failed(parser.parseRParen()) || failed(parser.parseGreater()))
724  return removeIdentifierAndFail(structContext, identifier);
725 
726  if (!identifier.empty()) {
727  if (failed(idStructTy.trySetBody(memberTypes, offsetInfo,
728  memberDecorationInfo)))
729  return Type();
730 
731  structContext.remove(identifier);
732  return idStructTy;
733  }
734 
735  return StructType::get(memberTypes, offsetInfo, memberDecorationInfo);
736 }
737 
// Parses a SPIR-V dialect type. It reads the type keyword (e.g. `array`,
// `ptr`, `struct`), dispatches to the matching parse*Type helper, and reports
// "unknown SPIR-V type" for any other keyword.
738 // spirv-type ::= array-type
739 // | element-type
740 // | image-type
741 // | pointer-type
742 // | runtime-array-type
743 // | sampled-image-type
744 // | struct-type
//             | cooperative-matrix-type
//             | matrix-type
746  StringRef keyword;
747  if (parser.parseKeyword(&keyword))
748  return Type();
749 
750  if (keyword == "array")
751  return parseArrayType(*this, parser);
752  if (keyword == "coopmatrix")
753  return parseCooperativeMatrixType(*this, parser);
754  if (keyword == "image")
755  return parseImageType(*this, parser);
756  if (keyword == "ptr")
757  return parsePointerType(*this, parser);
758  if (keyword == "rtarray")
759  return parseRuntimeArrayType(*this, parser);
760  if (keyword == "sampled_image")
761  return parseSampledImageType(*this, parser);
762  if (keyword == "struct")
763  return parseStructType(*this, parser);
764  if (keyword == "matrix")
765  return parseMatrixType(*this, parser);
766  parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword;
767  return Type();
768 }
769 
770 //===----------------------------------------------------------------------===//
771 // Type Printing
772 //===----------------------------------------------------------------------===//
773 
774 static void print(ArrayType type, DialectAsmPrinter &os) {
775  os << "array<" << type.getNumElements() << " x " << type.getElementType();
776  if (unsigned stride = type.getArrayStride())
777  os << ", stride=" << stride;
778  os << ">";
779 }
780 
781 static void print(RuntimeArrayType type, DialectAsmPrinter &os) {
782  os << "rtarray<" << type.getElementType();
783  if (unsigned stride = type.getArrayStride())
784  os << ", stride=" << stride;
785  os << ">";
786 }
787 
788 static void print(PointerType type, DialectAsmPrinter &os) {
789  os << "ptr<" << type.getPointeeType() << ", "
790  << stringifyStorageClass(type.getStorageClass()) << ">";
791 }
792 
793 static void print(ImageType type, DialectAsmPrinter &os) {
794  os << "image<" << type.getElementType() << ", " << stringifyDim(type.getDim())
795  << ", " << stringifyImageDepthInfo(type.getDepthInfo()) << ", "
796  << stringifyImageArrayedInfo(type.getArrayedInfo()) << ", "
797  << stringifyImageSamplingInfo(type.getSamplingInfo()) << ", "
798  << stringifyImageSamplerUseInfo(type.getSamplerUseInfo()) << ", "
799  << stringifyImageFormat(type.getImageFormat()) << ">";
800 }
801 
802 static void print(SampledImageType type, DialectAsmPrinter &os) {
803  os << "sampled_image<" << type.getImageType() << ">";
804 }
805 
/// Prints a struct type as `struct<(members)>` or `struct<name, (members)>`.
/// Each member is printed with an optional `[offset, Decoration=value, ...]`
/// suffix. An identified struct that is already being printed further up the
/// current call chain is emitted as the bare reference `struct<name>`, which
/// terminates recursion through self-referencing structs.
806 static void print(StructType type, DialectAsmPrinter &os) {
// Names of identified structs currently being printed on this thread.
807  thread_local SetVector<StringRef> structContext;
808 
809  os << "struct<";
810 
811  if (type.isIdentified()) {
812  os << type.getIdentifier();
813 
// Already inside this struct's definition: print only a reference.
814  if (structContext.count(type.getIdentifier())) {
815  os << ">";
816  return;
817  }
818 
819  os << ", ";
820  structContext.insert(type.getIdentifier());
821  }
822 
823  os << "(";
824 
825  auto printMember = [&](unsigned i) {
826  os << type.getElementType(i);
828  type.getMemberDecorations(i, decorations);
// The bracketed suffix is only emitted when there is something to show.
829  if (type.hasOffset() || !decorations.empty()) {
830  os << " [";
831  if (type.hasOffset()) {
832  os << type.getMemberOffset(i);
833  if (!decorations.empty())
834  os << ", ";
835  }
836  auto eachFn = [&os](spirv::StructType::MemberDecorationInfo decoration) {
837  os << stringifyDecoration(decoration.decoration);
838  if (decoration.hasValue) {
839  os << "=" << decoration.decorationValue;
840  }
841  };
842  llvm::interleaveComma(decorations, os, eachFn);
843  os << "]";
844  }
845  };
846  llvm::interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os,
847  printMember);
848  os << ")>";
849 
850  if (type.isIdentified())
851  structContext.remove(type.getIdentifier());
852 }
853 
// Emits `coopmatrix<RxCxelem, Scope>`, the same syntax parseCooperativeMatrixType
// accepts (rows and columns first, then the element type and scope).
855  os << "coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x";
856  os << type.getElementType() << ", " << stringifyScope(type.getScope());
857  os << ">";
858 }
859 
860 static void print(MatrixType type, DialectAsmPrinter &os) {
861  os << "matrix<" << type.getNumColumns() << " x " << type.getColumnType();
862  os << ">";
863 }
864 
/// Prints a SPIR-V dialect type by dispatching to the matching static print()
/// overload in this file. Reaching the default case means a SPIR-V type was
/// added without a printer.
865 void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
866  TypeSwitch<Type>(type)
869  [&](auto type) { print(type, os); })
870  .Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
871 }
872 
873 //===----------------------------------------------------------------------===//
874 // Constant
875 //===----------------------------------------------------------------------===//
876 
878  Attribute value, Type type,
879  Location loc) {
// Materializes `value` as a spv.Constant of `type` at `loc`. Returns nullptr
// (no materialization) when spv.Constant cannot be built for `type`.
880  if (!spirv::ConstantOp::isBuildableWith(type))
881  return nullptr;
882 
883  return builder.create<spirv::ConstantOp>(loc, type, value);
884 }
885 
886 //===----------------------------------------------------------------------===//
887 // Shader Interface ABI
888 //===----------------------------------------------------------------------===//
889 
890 LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op,
891  NamedAttribute attribute) {
892  StringRef symbol = attribute.getName().strref();
893  Attribute attr = attribute.getValue();
894 
895  if (symbol == spirv::getEntryPointABIAttrName()) {
896  if (!attr.isa<spirv::EntryPointABIAttr>()) {
897  return op->emitError("'")
898  << symbol << "' attribute must be an entry point ABI attribute";
899  }
900  } else if (symbol == spirv::getTargetEnvAttrName()) {
901  if (!attr.isa<spirv::TargetEnvAttr>())
902  return op->emitError("'") << symbol << "' must be a spirv::TargetEnvAttr";
903  } else {
904  return op->emitError("found unsupported '")
905  << symbol << "' attribute on operation";
906  }
907 
908  return success();
909 }
910 
911 /// Verifies the given SPIR-V `attribute` attached to a value of the given
912 /// `valueType` is valid.
/// Only the interface variable ABI attribute is accepted. It must be a
/// spirv::InterfaceVarABIAttr, and it may specify a storage class only when
/// `valueType` is an integer, index, or float type. Diagnostics are reported
/// at `loc`.
914  NamedAttribute attribute) {
915  StringRef symbol = attribute.getName().strref();
916  Attribute attr = attribute.getValue();
917 
918  if (symbol != spirv::getInterfaceVarABIAttrName())
919  return emitError(loc, "found unsupported '")
920  << symbol << "' attribute on region argument";
921 
922  auto varABIAttr = attr.dyn_cast<spirv::InterfaceVarABIAttr>();
923  if (!varABIAttr)
924  return emitError(loc, "'")
925  << symbol << "' must be a spirv::InterfaceVarABIAttr";
926 
// A storage class is only meaningful for scalar (int/index/float) values.
927  if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat())
928  return emitError(loc, "'") << symbol
929  << "' attribute cannot specify storage class "
930  "when attaching to a non-scalar value";
931 
932  return success();
933 }
934 
935 LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
936  unsigned regionIndex,
937  unsigned argIndex,
938  NamedAttribute attribute) {
939  return verifyRegionAttribute(
940  op->getLoc(), op->getRegion(regionIndex).getArgument(argIndex).getType(),
941  attribute);
942 }
943 
944 LogicalResult SPIRVDialect::verifyRegionResultAttribute(
945  Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/,
946  NamedAttribute attribute) {
947  return op->emitError("cannot attach SPIR-V attributes to region result");
948 }
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, BlockAndValueMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
TODO: Remove this file when SCCP and integer range analysis have been ported to the new framework...
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
This class contains a list of basic blocks and a link to the parent operation it is attached to...
Definition: Region.h:26
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:114
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
ImageFormat getImageFormat() const
Definition: SPIRVTypes.cpp:349
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:158
Optional< Type > parseAndVerify< Type >(SPIRVDialect const &dialect, DialectAsmParser &parser)
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:64
virtual SMLoc getCurrentLocation()=0
Get the location of the next token.
Type getPointeeType() const
Definition: SPIRVTypes.cpp:395
Block represents an ordered list of Operations.
Definition: Block.h:29
static Type parseImageType(SPIRVDialect const &dialect, DialectAsmParser &parser)
An attribute that specifies the information regarding the interface variable: descriptor set...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:89
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
static Type parseSampledImageType(SPIRVDialect const &dialect, DialectAsmParser &parser)
bool isa() const
Definition: Attributes.h:109
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
unsigned getNumElements() const
Definition: SPIRVTypes.cpp:60
StringRef getEntryPointABIAttrName()
Returns the attribute name for specifying entry point information.
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
static Type parseMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
bool isIdentified() const
Returns true if the StructType is identified.
Definition: SPIRVTypes.cpp:991
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
Definition: LogicalResult.h:68
static RuntimeArrayType get(Type elementType)
Definition: SPIRVTypes.cpp:448
Type getElementType() const
Definition: SPIRVTypes.cpp:331
static bool isValid(FloatType)
Returns true if the given float type is valid for the SPIR-V dialect.
Definition: SPIRVTypes.cpp:491
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
static Optional< IntTy > parseAndVerifyInteger(SPIRVDialect const &dialect, DialectAsmParser &parser)
virtual ParseResult parseComma()=0
Parse a , token.
StringRef getTargetEnvAttrName()
Returns the attribute name for specifying SPIR-V target environment.
static constexpr const bool value
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:424
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:144
BlockArgument getArgument(unsigned i)
Definition: Region.h:124
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
Definition: SPIRVTypes.cpp:952
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:380
ImageArrayedInfo getArrayedInfo() const
Definition: SPIRVTypes.cpp:337
static SampledImageType get(Type imageType)
Definition: SPIRVTypes.cpp:723
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
static Optional< ValTy > parseAndVerify(SPIRVDialect const &dialect, DialectAsmParser &parser)
static CooperativeMatrixNVType get(Type elementType, Scope scope, unsigned rows, unsigned columns)
Definition: SPIRVTypes.cpp:222
static Type parseAndVerifySampledImageType(SPIRVDialect const &dialect, DialectAsmParser &parser)
virtual ParseResult parseGreater()=0
Parse a '>' token.
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:48
U dyn_cast() const
Definition: Types.h:256
virtual ParseResult parseOptionalRSquare()=0
Parse a ] token if present.
Attributes are known-constant values of operations.
Definition: Attributes.h:24
ImageSamplerUseInfo getSamplerUseInfo() const
Definition: SPIRVTypes.cpp:345
Type getElementType() const
Definition: SPIRVTypes.cpp:62
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:43
static Type parseArrayType(SPIRVDialect const &dialect, DialectAsmParser &parser)
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:32
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
virtual ParseResult parseRParen()=0
Parse a ) token.
This is the interface that must be implemented by the dialects of operations to be inlined...
Definition: InliningUtils.h:40
static Type parseStructType(SPIRVDialect const &dialect, DialectAsmParser &parser)
ParseResult parseKeyword(StringRef keyword, const Twine &msg="")
Parse a given keyword.
virtual ParseResult parseLess()=0
Parse a '<' token.
ImageDepthInfo getDepthInfo() const
Definition: SPIRVTypes.cpp:335
StringRef getIdentifier() const
For literal structs, return an empty string.
Definition: SPIRVTypes.cpp:989
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:161
static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect, DialectAsmParser &parser, unsigned &stride)
Parses an optional , stride = N assembly segment.
Optional< unsigned > parseAndVerify< unsigned >(SPIRVDialect const &dialect, DialectAsmParser &parser)
uint64_t getMemberOffset(unsigned) const
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
Definition: SPIRVTypes.cpp:975
virtual ParseResult parseRSquare()=0
Parse a ] token.
static void print(ArrayType type, DialectAsmPrinter &os)
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:391
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:37
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:230
unsigned getNumElements() const
Definition: SPIRVTypes.cpp:993
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
virtual ParseResult parseOptionalGreater()=0
Parse a '>' token if present.
Type getElementType(unsigned) const
Definition: SPIRVTypes.cpp:995
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:458
Dim
Dimension level type for a tensor (undef means index does not appear).
Definition: Merger.h:24
Type parseType(DialectAsmParser &parser)
Parses an LLVM dialect type.
static Type parseRuntimeArrayType(SPIRVDialect const &dialect, DialectAsmParser &parser)
Type getType() const
Return the type of this value.
Definition: Value.h:118
static Type parseAndVerifyType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseCooperativeMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
unsigned getRows() const
Return the number of rows of the matrix.
Definition: SPIRVTypes.cpp:234
U dyn_cast() const
Definition: Attributes.h:124
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition: SPIRVTypes.h:163
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual ParseResult parseType(Type &result)=0
Parse a type.
static LogicalResult verifyRegionAttribute(Location loc, Type valueType, NamedAttribute attribute)
Verifies the given SPIR-V attribute attached to a value of the given valueType is valid...
SPIR-V struct type.
Definition: SPIRVTypes.h:278
StorageClass getStorageClass() const
Definition: SPIRVTypes.cpp:397
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
static Type parsePointerType(SPIRVDialect const &dialect, DialectAsmParser &parser)
bool hasValue() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:47
virtual ParseResult parseEqual()=0
Parse a = token.
Scope getScope() const
Return the scope of the cooperative matrix.
Definition: SPIRVTypes.cpp:232
unsigned getNumColumns() const
Returns the number of columns.
static MatrixType get(Type columnType, uint32_t columnCount)
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
Definition: SPIRVTypes.cpp:965
static ParseResult parseEnumKeywordAttr(EnumClass &value, ParserType &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next keyword in parser as an enumerant of the given EnumClass.
Definition: ParserUtils.h:27
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
bool isa() const
Definition: Types.h:246
static ParseResult parseStructMemberDecorations(SPIRVDialect const &dialect, DialectAsmParser &parser, ArrayRef< Type > memberTypes, SmallVectorImpl< StructType::OffsetInfo > &offsetInfo, SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorationInfo)
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:231
This class represents success/failure for parsing-like operations that find it important to chain tog...
This class helps build Operations.
Definition: Builders.h:184
Type getColumnType() const
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
ImageSamplingInfo getSamplingInfo() const
Definition: SPIRVTypes.cpp:341
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:484
static bool containsReturn(Region &region)
Returns true if the given region contains spv.Return or spv.ReturnValue ops.
An attribute that specifies the target version, allowed extensions and capabilities, and resource limits.
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type. ...
Definition: FoldUtils.cpp:50
unsigned getColumns() const
Return the number of columns of the matrix.
Definition: SPIRVTypes.cpp:236
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
bool isBF16() const
Definition: Types.cpp:21