MLIR  17.0.0git
SPIRVOps.cpp
Go to the documentation of this file.
1 //===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the operations in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
22 #include "mlir/IR/Builders.h"
23 #include "mlir/IR/BuiltinTypes.h"
25 #include "mlir/IR/OpDefinition.h"
27 #include "mlir/IR/Operation.h"
28 #include "mlir/IR/TypeUtilities.h"
30 #include "llvm/ADT/APFloat.h"
31 #include "llvm/ADT/APInt.h"
32 #include "llvm/ADT/ArrayRef.h"
33 #include "llvm/ADT/STLExtras.h"
34 #include "llvm/ADT/StringExtras.h"
35 #include "llvm/Support/FormatVariadic.h"
36 #include <cassert>
37 #include <numeric>
38 
39 using namespace mlir;
40 
// String constants shared by the custom parsers, printers and verifiers in this
// file. Most are attribute names passed to getAttr/addAttribute/parseAttribute;
// a few (e.g. kControl, kClusterSize) are used as assembly keywords.
41 // TODO: generate these strings using ODS.
42 constexpr char kAlignmentAttrName[] = "alignment";
43 constexpr char kBranchWeightAttrName[] = "branch_weights";
44 constexpr char kCallee[] = "callee";
45 constexpr char kClusterSize[] = "cluster_size";
46 constexpr char kControl[] = "control";
47 constexpr char kDefaultValueAttrName[] = "default_value";
48 constexpr char kEqualSemanticsAttrName[] = "equal_semantics";
49 constexpr char kExecutionScopeAttrName[] = "execution_scope";
50 constexpr char kFnNameAttrName[] = "fn";
51 constexpr char kGroupOperationAttrName[] = "group_operation";
52 constexpr char kIndicesAttrName[] = "indices";
53 constexpr char kInitializerAttrName[] = "initializer";
54 constexpr char kInterfaceAttrName[] = "interface";
55 constexpr char kMemoryAccessAttrName[] = "memory_access";
56 constexpr char kMemoryScopeAttrName[] = "memory_scope";
57 constexpr char kPackedVectorFormatAttrName[] = "format";
58 constexpr char kSemanticsAttrName[] = "semantics";
59 constexpr char kSourceAlignmentAttrName[] = "source_alignment";
60 constexpr char kSourceMemoryAccessAttrName[] = "source_memory_access";
61 constexpr char kSpecIdAttrName[] = "spec_id";
62 constexpr char kTypeAttrName[] = "type";
63 constexpr char kUnequalSemanticsAttrName[] = "unequal_semantics";
64 constexpr char kValueAttrName[] = "value";
65 constexpr char kValuesAttrName[] = "values";
66 constexpr char kCompositeSpecConstituentsName[] = "constituents";
67 
68 //===----------------------------------------------------------------------===//
69 // Common utility functions
70 //===----------------------------------------------------------------------===//
71 
73  OperationState &result) {
75  Type type;
76  // If the operand list is in-between parentheses, then we have a generic form.
77  // (see the fallback in `printOneResultOp`).
78  SMLoc loc = parser.getCurrentLocation();
79  if (!parser.parseOptionalLParen()) {
80  if (parser.parseOperandList(ops) || parser.parseRParen() ||
81  parser.parseOptionalAttrDict(result.attributes) ||
82  parser.parseColon() || parser.parseType(type))
83  return failure();
84  auto fnType = type.dyn_cast<FunctionType>();
85  if (!fnType) {
86  parser.emitError(loc, "expected function type");
87  return failure();
88  }
89  if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands))
90  return failure();
91  result.addTypes(fnType.getResults());
92  return success();
93  }
94  return failure(parser.parseOperandList(ops) ||
95  parser.parseOptionalAttrDict(result.attributes) ||
96  parser.parseColonType(type) ||
97  parser.resolveOperands(ops, type, result.operands) ||
98  parser.addTypeToList(type, result.types));
99 }
100 
102  assert(op->getNumResults() == 1 && "op should have one result");
103 
104  // If not all the operand and result types are the same, just use the
105  // generic assembly form to avoid omitting information in printing.
106  auto resultType = op->getResult(0).getType();
107  if (llvm::any_of(op->getOperandTypes(),
108  [&](Type type) { return type != resultType; })) {
109  p.printGenericOp(op, /*printOpName=*/false);
110  return;
111  }
112 
113  p << ' ';
114  p.printOperands(op->getOperands());
116  // Now we can output only one type for all operands and the result.
117  p << " : " << resultType;
118 }
119 
120 /// Returns true if the given op is a function-like op or nested in a
121 /// function-like op without a module-like op in the middle.
123  if (!op)
124  return false;
125  if (op->hasTrait<OpTrait::SymbolTable>())
126  return false;
127  if (isa<FunctionOpInterface>(op))
128  return true;
130 }
131 
132 /// Returns true if the given op is a module-like op that maintains a symbol
133 /// table.
135  return op && op->hasTrait<OpTrait::SymbolTable>();
136 }
137 
138 static LogicalResult extractValueFromConstOp(Operation *op, int32_t &value) {
139  auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
140  if (!constOp) {
141  return failure();
142  }
143  auto valueAttr = constOp.getValue();
144  auto integerValueAttr = valueAttr.dyn_cast<IntegerAttr>();
145  if (!integerValueAttr) {
146  return failure();
147  }
148 
149  if (integerValueAttr.getType().isSignlessInteger())
150  value = integerValueAttr.getInt();
151  else
152  value = integerValueAttr.getSInt();
153 
154  return success();
155 }
156 
157 template <typename Ty>
158 static ArrayAttr
160  function_ref<StringRef(Ty)> stringifyFn) {
161  if (enumValues.empty()) {
162  return nullptr;
163  }
164  SmallVector<StringRef, 1> enumValStrs;
165  enumValStrs.reserve(enumValues.size());
166  for (auto val : enumValues) {
167  enumValStrs.emplace_back(stringifyFn(val));
168  }
169  return builder.getStrArrayAttr(enumValStrs);
170 }
171 
172 /// Parses the next string attribute in `parser` as an enumerant of the given
173 /// `EnumClass`.
174 template <typename EnumClass>
175 static ParseResult
176 parseEnumStrAttr(EnumClass &value, OpAsmParser &parser,
177  StringRef attrName = spirv::attributeName<EnumClass>()) {
178  Attribute attrVal;
179  NamedAttrList attr;
180  auto loc = parser.getCurrentLocation();
181  if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(),
182  attrName, attr))
183  return failure();
184  if (!attrVal.isa<StringAttr>())
185  return parser.emitError(loc, "expected ")
186  << attrName << " attribute specified as string";
187  auto attrOptional =
188  spirv::symbolizeEnum<EnumClass>(attrVal.cast<StringAttr>().getValue());
189  if (!attrOptional)
190  return parser.emitError(loc, "invalid ")
191  << attrName << " attribute specification: " << attrVal;
192  value = *attrOptional;
193  return success();
194 }
195 
196 /// Parses the next string attribute in `parser` as an enumerant of the given
197 /// `EnumClass` and inserts the enumerant into `state` as an 32-bit integer
198 /// attribute with the enum class's name as attribute name.
199 template <typename EnumAttrClass,
200  typename EnumClass = typename EnumAttrClass::ValueType>
201 static ParseResult
202 parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, OperationState &state,
203  StringRef attrName = spirv::attributeName<EnumClass>()) {
204  if (parseEnumStrAttr(value, parser))
205  return failure();
206  state.addAttribute(attrName,
207  parser.getBuilder().getAttr<EnumAttrClass>(value));
208  return success();
209 }
210 
211 /// Parses the next keyword in `parser` as an enumerant of the given `EnumClass`
212 /// and inserts the enumerant into `state` as an 32-bit integer attribute with
213 /// the enum class's name as attribute name.
214 template <typename EnumAttrClass,
215  typename EnumClass = typename EnumAttrClass::ValueType>
216 static ParseResult
217 parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser,
218  OperationState &state,
219  StringRef attrName = spirv::attributeName<EnumClass>()) {
220  if (parseEnumKeywordAttr(value, parser))
221  return failure();
222  state.addAttribute(attrName,
223  parser.getBuilder().getAttr<EnumAttrClass>(value));
224  return success();
225 }
226 
227 /// Parses Function, Selection and Loop control attributes. If no control is
228 /// specified, "None" is used as a default.
229 template <typename EnumAttrClass, typename EnumClass>
230 static ParseResult
232  StringRef attrName = spirv::attributeName<EnumClass>()) {
233  if (succeeded(parser.parseOptionalKeyword(kControl))) {
234  EnumClass control;
235  if (parser.parseLParen() ||
236  parseEnumKeywordAttr<EnumAttrClass>(control, parser, state) ||
237  parser.parseRParen())
238  return failure();
239  return success();
240  }
241  // Set control to "None" otherwise.
242  Builder builder = parser.getBuilder();
243  state.addAttribute(attrName,
244  builder.getAttr<EnumAttrClass>(static_cast<EnumClass>(0)));
245  return success();
246 }
247 
248 /// Parses optional memory access attributes attached to a memory access
249 /// operand/pointer. Specifically, parses the following syntax:
250 /// (`[` memory-access `]`)?
251 /// where:
252 /// memory-access ::= `"None"` | `"Volatile"` | `"Aligned", `
253 /// integer-literal | `"NonTemporal"`
255  OperationState &state) {
256  // Parse an optional list of attributes starting with '['
257  if (parser.parseOptionalLSquare()) {
258  // Nothing to do
259  return success();
260  }
261 
262  spirv::MemoryAccess memoryAccessAttr;
263  if (parseEnumStrAttr<spirv::MemoryAccessAttr>(memoryAccessAttr, parser, state,
265  return failure();
266 
267  if (spirv::bitEnumContainsAll(memoryAccessAttr,
268  spirv::MemoryAccess::Aligned)) {
269  // Parse integer attribute for alignment.
270  Attribute alignmentAttr;
271  Type i32Type = parser.getBuilder().getIntegerType(32);
272  if (parser.parseComma() ||
273  parser.parseAttribute(alignmentAttr, i32Type, kAlignmentAttrName,
274  state.attributes)) {
275  return failure();
276  }
277  }
278  return parser.parseRSquare();
279 }
280 
281 // TODO Make sure to merge this and the previous function into one template
282 // parameterized by memory access attribute name and alignment. Doing so now
283 // results in VS2017 producing an internal error (at the call site) that's
284 // not detailed enough to understand what is happening.
286  OperationState &state) {
287  // Parse an optional list of attributes starting with '['
288  if (parser.parseOptionalLSquare()) {
289  // Nothing to do
290  return success();
291  }
292 
293  spirv::MemoryAccess memoryAccessAttr;
294  if (parseEnumStrAttr<spirv::MemoryAccessAttr>(memoryAccessAttr, parser, state,
296  return failure();
297 
298  if (spirv::bitEnumContainsAll(memoryAccessAttr,
299  spirv::MemoryAccess::Aligned)) {
300  // Parse integer attribute for alignment.
301  Attribute alignmentAttr;
302  Type i32Type = parser.getBuilder().getIntegerType(32);
303  if (parser.parseComma() ||
304  parser.parseAttribute(alignmentAttr, i32Type, kSourceAlignmentAttrName,
305  state.attributes)) {
306  return failure();
307  }
308  }
309  return parser.parseRSquare();
310 }
311 
312 template <typename MemoryOpTy>
314  MemoryOpTy memoryOp, OpAsmPrinter &printer,
315  SmallVectorImpl<StringRef> &elidedAttrs,
316  std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt,
317  std::optional<uint32_t> alignmentAttrValue = std::nullopt) {
318  // Print optional memory access attribute.
319  if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
320  : memoryOp.getMemoryAccess())) {
321  elidedAttrs.push_back(kMemoryAccessAttrName);
322 
323  printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
324 
325  if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
326  // Print integer alignment attribute.
327  if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
328  : memoryOp.getAlignment())) {
329  elidedAttrs.push_back(kAlignmentAttrName);
330  printer << ", " << *alignment;
331  }
332  }
333  printer << "]";
334  }
335  elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
336 }
337 
338 // TODO Make sure to merge this and the previous function into one template
339 // parameterized by memory access attribute name and alignment. Doing so now
340 // results in VS2017 producing an internal error (at the call site) that's
341 // not detailed enough to understand what is happening.
342 template <typename MemoryOpTy>
344  MemoryOpTy memoryOp, OpAsmPrinter &printer,
345  SmallVectorImpl<StringRef> &elidedAttrs,
346  std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt,
347  std::optional<uint32_t> alignmentAttrValue = std::nullopt) {
348 
349  printer << ", ";
350 
351  // Print optional memory access attribute.
352  if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
353  : memoryOp.getMemoryAccess())) {
354  elidedAttrs.push_back(kSourceMemoryAccessAttrName);
355 
356  printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
357 
358  if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
359  // Print integer alignment attribute.
360  if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
361  : memoryOp.getAlignment())) {
362  elidedAttrs.push_back(kSourceAlignmentAttrName);
363  printer << ", " << *alignment;
364  }
365  }
366  printer << "]";
367  }
368  elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
369 }
370 
372  spirv::ImageOperandsAttr &attr) {
373  // Expect image operands
374  if (parser.parseOptionalLSquare())
375  return success();
376 
377  spirv::ImageOperands imageOperands;
378  if (parseEnumStrAttr(imageOperands, parser))
379  return failure();
380 
381  attr = spirv::ImageOperandsAttr::get(parser.getContext(), imageOperands);
382 
383  return parser.parseRSquare();
384 }
385 
386 static void printImageOperands(OpAsmPrinter &printer, Operation *imageOp,
387  spirv::ImageOperandsAttr attr) {
388  if (attr) {
389  auto strImageOperands = stringifyImageOperands(attr.getValue());
390  printer << "[\"" << strImageOperands << "\"]";
391  }
392 }
393 
394 template <typename Op>
396  spirv::ImageOperandsAttr attr,
397  Operation::operand_range operands) {
398  if (!attr) {
399  if (operands.empty())
400  return success();
401 
402  return imageOp.emitError("the Image Operands should encode what operands "
403  "follow, as per Image Operands");
404  }
405 
406  // TODO: Add the validation rules for the following Image Operands.
407  spirv::ImageOperands noSupportOperands =
408  spirv::ImageOperands::Bias | spirv::ImageOperands::Lod |
409  spirv::ImageOperands::Grad | spirv::ImageOperands::ConstOffset |
410  spirv::ImageOperands::Offset | spirv::ImageOperands::ConstOffsets |
411  spirv::ImageOperands::Sample | spirv::ImageOperands::MinLod |
412  spirv::ImageOperands::MakeTexelAvailable |
413  spirv::ImageOperands::MakeTexelVisible |
414  spirv::ImageOperands::SignExtend | spirv::ImageOperands::ZeroExtend;
415 
416  if (spirv::bitEnumContainsAll(attr.getValue(), noSupportOperands))
417  llvm_unreachable("unimplemented operands of Image Operands");
418 
419  return success();
420 }
421 
423  bool requireSameBitWidth = true,
424  bool skipBitWidthCheck = false) {
425  // Some CastOps have no limit on bit widths for result and operand type.
426  if (skipBitWidthCheck)
427  return success();
428 
429  Type operandType = op->getOperand(0).getType();
430  Type resultType = op->getResult(0).getType();
431 
432  // ODS checks that result type and operand type have the same shape.
433  if (auto vectorType = operandType.dyn_cast<VectorType>()) {
434  operandType = vectorType.getElementType();
435  resultType = resultType.cast<VectorType>().getElementType();
436  }
437 
438  if (auto coopMatrixType =
439  operandType.dyn_cast<spirv::CooperativeMatrixNVType>()) {
440  operandType = coopMatrixType.getElementType();
441  resultType =
443  }
444 
445  if (auto jointMatrixType =
446  operandType.dyn_cast<spirv::JointMatrixINTELType>()) {
447  operandType = jointMatrixType.getElementType();
448  resultType =
450  }
451 
452  auto operandTypeBitWidth = operandType.getIntOrFloatBitWidth();
453  auto resultTypeBitWidth = resultType.getIntOrFloatBitWidth();
454  auto isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth;
455 
456  if (requireSameBitWidth) {
457  if (!isSameBitWidth) {
458  return op->emitOpError(
459  "expected the same bit widths for operand type and result "
460  "type, but provided ")
461  << operandType << " and " << resultType;
462  }
463  return success();
464  }
465 
466  if (isSameBitWidth) {
467  return op->emitOpError(
468  "expected the different bit widths for operand type and result "
469  "type, but provided ")
470  << operandType << " and " << resultType;
471  }
472  return success();
473 }
474 
475 template <typename MemoryOpTy>
476 static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
477  // ODS checks for attributes values. Just need to verify that if the
478  // memory-access attribute is Aligned, then the alignment attribute must be
479  // present.
480  auto *op = memoryOp.getOperation();
481  auto memAccessAttr = op->getAttr(kMemoryAccessAttrName);
482  if (!memAccessAttr) {
483  // Alignment attribute shouldn't be present if memory access attribute is
484  // not present.
485  if (op->getAttr(kAlignmentAttrName)) {
486  return memoryOp.emitOpError(
487  "invalid alignment specification without aligned memory access "
488  "specification");
489  }
490  return success();
491  }
492 
493  auto memAccess = memAccessAttr.template cast<spirv::MemoryAccessAttr>();
494 
495  if (!memAccess) {
496  return memoryOp.emitOpError("invalid memory access specifier: ")
497  << memAccessAttr;
498  }
499 
500  if (spirv::bitEnumContainsAll(memAccess.getValue(),
501  spirv::MemoryAccess::Aligned)) {
502  if (!op->getAttr(kAlignmentAttrName)) {
503  return memoryOp.emitOpError("missing alignment value");
504  }
505  } else {
506  if (op->getAttr(kAlignmentAttrName)) {
507  return memoryOp.emitOpError(
508  "invalid alignment specification with non-aligned memory access "
509  "specification");
510  }
511  }
512  return success();
513 }
514 
515 // TODO Make sure to merge this and the previous function into one template
516 // parameterized by memory access attribute name and alignment. Doing so now
517 // results in VS2017 producing an internal error (at the call site) that's
518 // not detailed enough to understand what is happening.
519 template <typename MemoryOpTy>
521  // ODS checks for attributes values. Just need to verify that if the
522  // memory-access attribute is Aligned, then the alignment attribute must be
523  // present.
524  auto *op = memoryOp.getOperation();
525  auto memAccessAttr = op->getAttr(kSourceMemoryAccessAttrName);
526  if (!memAccessAttr) {
527  // Alignment attribute shouldn't be present if memory access attribute is
528  // not present.
529  if (op->getAttr(kSourceAlignmentAttrName)) {
530  return memoryOp.emitOpError(
531  "invalid alignment specification without aligned memory access "
532  "specification");
533  }
534  return success();
535  }
536 
537  auto memAccess = memAccessAttr.template cast<spirv::MemoryAccessAttr>();
538 
539  if (!memAccess) {
540  return memoryOp.emitOpError("invalid memory access specifier: ")
541  << memAccess;
542  }
543 
544  if (spirv::bitEnumContainsAll(memAccess.getValue(),
545  spirv::MemoryAccess::Aligned)) {
546  if (!op->getAttr(kSourceAlignmentAttrName)) {
547  return memoryOp.emitOpError("missing alignment value");
548  }
549  } else {
550  if (op->getAttr(kSourceAlignmentAttrName)) {
551  return memoryOp.emitOpError(
552  "invalid alignment specification with non-aligned memory access "
553  "specification");
554  }
555  }
556  return success();
557 }
558 
559 static LogicalResult
560 verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics) {
561  // According to the SPIR-V specification:
562  // "Despite being a mask and allowing multiple bits to be combined, it is
563  // invalid for more than one of these four bits to be set: Acquire, Release,
564  // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and
565  // Release semantics is done by setting the AcquireRelease bit, not by setting
566  // two bits."
567  auto atMostOneInSet = spirv::MemorySemantics::Acquire |
568  spirv::MemorySemantics::Release |
569  spirv::MemorySemantics::AcquireRelease |
570  spirv::MemorySemantics::SequentiallyConsistent;
571 
572  auto bitCount =
573  llvm::popcount(static_cast<uint32_t>(memorySemantics & atMostOneInSet));
574  if (bitCount > 1) {
575  return op->emitError(
576  "expected at most one of these four memory constraints "
577  "to be set: `Acquire`, `Release`,"
578  "`AcquireRelease` or `SequentiallyConsistent`");
579  }
580  return success();
581 }
582 
583 template <typename LoadStoreOpTy>
584 static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr,
585  Value val) {
586  // ODS already checks ptr is spirv::PointerType. Just check that the pointee
587  // type of the pointer and the type of the value are the same
588  //
589  // TODO: Check that the value type satisfies restrictions of
590  // SPIR-V OpLoad/OpStore operations
591  if (val.getType() !=
593  return op.emitOpError("mismatch in result type and pointer type");
594  }
595  return success();
596 }
597 
598 template <typename BlockReadWriteOpTy>
600  Value ptr, Value val) {
601  auto valType = val.getType();
602  if (auto valVecTy = valType.dyn_cast<VectorType>())
603  valType = valVecTy.getElementType();
604 
605  if (valType != ptr.getType().cast<spirv::PointerType>().getPointeeType()) {
606  return op.emitOpError("mismatch in result type and pointer type");
607  }
608  return success();
609 }
610 
612  OperationState &state) {
613  auto builtInName = llvm::convertToSnakeFromCamelCase(
614  stringifyDecoration(spirv::Decoration::BuiltIn));
615  if (succeeded(parser.parseOptionalKeyword("bind"))) {
616  Attribute set, binding;
617  // Parse optional descriptor binding
618  auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
619  stringifyDecoration(spirv::Decoration::DescriptorSet));
620  auto bindingName = llvm::convertToSnakeFromCamelCase(
621  stringifyDecoration(spirv::Decoration::Binding));
622  Type i32Type = parser.getBuilder().getIntegerType(32);
623  if (parser.parseLParen() ||
624  parser.parseAttribute(set, i32Type, descriptorSetName,
625  state.attributes) ||
626  parser.parseComma() ||
627  parser.parseAttribute(binding, i32Type, bindingName,
628  state.attributes) ||
629  parser.parseRParen()) {
630  return failure();
631  }
632  } else if (succeeded(parser.parseOptionalKeyword(builtInName))) {
633  StringAttr builtIn;
634  if (parser.parseLParen() ||
635  parser.parseAttribute(builtIn, builtInName, state.attributes) ||
636  parser.parseRParen()) {
637  return failure();
638  }
639  }
640 
641  // Parse other attributes
642  if (parser.parseOptionalAttrDict(state.attributes))
643  return failure();
644 
645  return success();
646 }
647 
649  SmallVectorImpl<StringRef> &elidedAttrs) {
650  // Print optional descriptor binding
651  auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
652  stringifyDecoration(spirv::Decoration::DescriptorSet));
653  auto bindingName = llvm::convertToSnakeFromCamelCase(
654  stringifyDecoration(spirv::Decoration::Binding));
655  auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
656  auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
657  if (descriptorSet && binding) {
658  elidedAttrs.push_back(descriptorSetName);
659  elidedAttrs.push_back(bindingName);
660  printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
661  << ")";
662  }
663 
664  // Print BuiltIn attribute if present
665  auto builtInName = llvm::convertToSnakeFromCamelCase(
666  stringifyDecoration(spirv::Decoration::BuiltIn));
667  if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
668  printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
669  elidedAttrs.push_back(builtInName);
670  }
671 
672  printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
673 }
674 
675 // Get bit width of types.
676 static unsigned getBitWidth(Type type) {
677  if (type.isa<spirv::PointerType>()) {
678  // Just return 64 bits for pointer types for now.
679  // TODO: Make sure not caller relies on the actual pointer width value.
680  return 64;
681  }
682 
683  if (type.isIntOrFloat())
684  return type.getIntOrFloatBitWidth();
685 
686  if (auto vectorType = type.dyn_cast<VectorType>()) {
687  assert(vectorType.getElementType().isIntOrFloat());
688  return vectorType.getNumElements() *
689  vectorType.getElementType().getIntOrFloatBitWidth();
690  }
691  llvm_unreachable("unhandled bit width computation for type");
692 }
693 
694 /// Walks the given type hierarchy with the given indices, potentially down
695 /// to component granularity, to select an element type. Returns null type and
696 /// emits errors with the given loc on failure.
697 static Type
699  function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
700  if (indices.empty()) {
701  emitErrorFn("expected at least one index for spirv.CompositeExtract");
702  return nullptr;
703  }
704 
705  for (auto index : indices) {
706  if (auto cType = type.dyn_cast<spirv::CompositeType>()) {
707  if (cType.hasCompileTimeKnownNumElements() &&
708  (index < 0 ||
709  static_cast<uint64_t>(index) >= cType.getNumElements())) {
710  emitErrorFn("index ") << index << " out of bounds for " << type;
711  return nullptr;
712  }
713  type = cType.getElementType(index);
714  } else {
715  emitErrorFn("cannot extract from non-composite type ")
716  << type << " with index " << index;
717  return nullptr;
718  }
719  }
720  return type;
721 }
722 
723 static Type
725  function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
726  auto indicesArrayAttr = indices.dyn_cast<ArrayAttr>();
727  if (!indicesArrayAttr) {
728  emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
729  return nullptr;
730  }
731  if (indicesArrayAttr.empty()) {
732  emitErrorFn("expected at least one index for spirv.CompositeExtract");
733  return nullptr;
734  }
735 
736  SmallVector<int32_t, 2> indexVals;
737  for (auto indexAttr : indicesArrayAttr) {
738  auto indexIntAttr = indexAttr.dyn_cast<IntegerAttr>();
739  if (!indexIntAttr) {
740  emitErrorFn("expected an 32-bit integer for index, but found '")
741  << indexAttr << "'";
742  return nullptr;
743  }
744  indexVals.push_back(indexIntAttr.getInt());
745  }
746  return getElementType(type, indexVals, emitErrorFn);
747 }
748 
749 static Type getElementType(Type type, Attribute indices, Location loc) {
750  auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
751  return ::mlir::emitError(loc, err);
752  };
753  return getElementType(type, indices, errorFn);
754 }
755 
756 static Type getElementType(Type type, Attribute indices, OpAsmParser &parser,
757  SMLoc loc) {
758  auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
759  return parser.emitError(loc, err);
760  };
761  return getElementType(type, indices, errorFn);
762 }
763 
764 /// Returns true if the given `block` only contains one `spirv.mlir.merge` op.
765 static inline bool isMergeBlock(Block &block) {
766  return !block.empty() && std::next(block.begin()) == block.end() &&
767  isa<spirv::MergeOp>(block.front());
768 }
769 
770 template <typename ExtendedBinaryOp>
771 static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op) {
772  auto resultType = op.getType().template cast<spirv::StructType>();
773  if (resultType.getNumElements() != 2)
774  return op.emitOpError("expected result struct type containing two members");
775 
776  if (!llvm::all_equal({op.getOperand1().getType(), op.getOperand2().getType(),
777  resultType.getElementType(0),
778  resultType.getElementType(1)}))
779  return op.emitOpError(
780  "expected all operand types and struct member types are the same");
781 
782  return success();
783 }
784 
786  OperationState &result) {
788  if (parser.parseOptionalAttrDict(result.attributes) ||
789  parser.parseOperandList(operands) || parser.parseColon())
790  return failure();
791 
792  Type resultType;
793  SMLoc loc = parser.getCurrentLocation();
794  if (parser.parseType(resultType))
795  return failure();
796 
797  auto structType = resultType.dyn_cast<spirv::StructType>();
798  if (!structType || structType.getNumElements() != 2)
799  return parser.emitError(loc, "expected spirv.struct type with two members");
800 
801  SmallVector<Type, 2> operandTypes(2, structType.getElementType(0));
802  if (parser.resolveOperands(operands, operandTypes, loc, result.operands))
803  return failure();
804 
805  result.addTypes(resultType);
806  return success();
807 }
808 
810  OpAsmPrinter &printer) {
811  printer << ' ';
812  printer.printOptionalAttrDict(op->getAttrs());
813  printer.printOperands(op->getOperands());
814  printer << " : " << op->getResultTypes().front();
815 }
816 
817 //===----------------------------------------------------------------------===//
818 // Common parsers and printers
819 //===----------------------------------------------------------------------===//
820 
821 // Parses an atomic update op. If the update op does not take a value (like
822 // AtomicIIncrement) `hasValue` must be false.
824  OperationState &state, bool hasValue) {
825  spirv::Scope scope;
826  spirv::MemorySemantics memoryScope;
828  OpAsmParser::UnresolvedOperand ptrInfo, valueInfo;
829  Type type;
830  SMLoc loc;
831  if (parseEnumStrAttr<spirv::ScopeAttr>(scope, parser, state,
833  parseEnumStrAttr<spirv::MemorySemanticsAttr>(memoryScope, parser, state,
835  parser.parseOperandList(operandInfo, (hasValue ? 2 : 1)) ||
836  parser.getCurrentLocation(&loc) || parser.parseColonType(type))
837  return failure();
838 
839  auto ptrType = type.dyn_cast<spirv::PointerType>();
840  if (!ptrType)
841  return parser.emitError(loc, "expected pointer type");
842 
843  SmallVector<Type, 2> operandTypes;
844  operandTypes.push_back(ptrType);
845  if (hasValue)
846  operandTypes.push_back(ptrType.getPointeeType());
847  if (parser.resolveOperands(operandInfo, operandTypes, parser.getNameLoc(),
848  state.operands))
849  return failure();
850  return parser.addTypeToList(ptrType.getPointeeType(), state.types);
851 }
852 
853 // Prints an atomic update op.
854 static void printAtomicUpdateOp(Operation *op, OpAsmPrinter &printer) {
855  printer << " \"";
856  auto scopeAttr = op->getAttrOfType<spirv::ScopeAttr>(kMemoryScopeAttrName);
857  printer << spirv::stringifyScope(scopeAttr.getValue()) << "\" \"";
858  auto memorySemanticsAttr =
859  op->getAttrOfType<spirv::MemorySemanticsAttr>(kSemanticsAttrName);
860  printer << spirv::stringifyMemorySemantics(memorySemanticsAttr.getValue())
861  << "\" " << op->getOperands() << " : " << op->getOperand(0).getType();
862 }
863 
// Returns a human-readable name for the type category `T`. Only declared
// here; the explicit specializations that follow supply the actual strings.
864 template <typename T>
865 static StringRef stringifyTypeName();
866 
867 template <>
869  return "integer";
870 }
871 
872 template <>
874  return "float";
875 }
876 
877 // Verifies an atomic update op.
878 template <typename ExpectedElementType>
880  auto ptrType = op->getOperand(0).getType().cast<spirv::PointerType>();
881  auto elementType = ptrType.getPointeeType();
882  if (!elementType.isa<ExpectedElementType>())
883  return op->emitOpError() << "pointer operand must point to an "
884  << stringifyTypeName<ExpectedElementType>()
885  << " value, found " << elementType;
886 
887  if (op->getNumOperands() > 1) {
888  auto valueType = op->getOperand(1).getType();
889  if (valueType != elementType)
890  return op->emitOpError("expected value to have the same type as the "
891  "pointer operand's pointee type ")
892  << elementType << ", but found " << valueType;
893  }
894  auto memorySemantics =
895  op->getAttrOfType<spirv::MemorySemanticsAttr>(kSemanticsAttrName)
896  .getValue();
897  if (failed(verifyMemorySemantics(op, memorySemantics))) {
898  return failure();
899  }
900  return success();
901 }
902 
904  OperationState &state) {
905  spirv::Scope executionScope;
906  spirv::GroupOperation groupOperation;
908  if (parseEnumStrAttr<spirv::ScopeAttr>(executionScope, parser, state,
910  parseEnumStrAttr<spirv::GroupOperationAttr>(groupOperation, parser, state,
912  parser.parseOperand(valueInfo))
913  return failure();
914 
915  std::optional<OpAsmParser::UnresolvedOperand> clusterSizeInfo;
917  clusterSizeInfo = OpAsmParser::UnresolvedOperand();
918  if (parser.parseLParen() || parser.parseOperand(*clusterSizeInfo) ||
919  parser.parseRParen())
920  return failure();
921  }
922 
923  Type resultType;
924  if (parser.parseColonType(resultType))
925  return failure();
926 
927  if (parser.resolveOperand(valueInfo, resultType, state.operands))
928  return failure();
929 
930  if (clusterSizeInfo) {
931  Type i32Type = parser.getBuilder().getIntegerType(32);
932  if (parser.resolveOperand(*clusterSizeInfo, i32Type, state.operands))
933  return failure();
934  }
935 
936  return parser.addTypeToList(resultType, state.types);
937 }
938 
940  OpAsmPrinter &printer) {
941  printer
942  << " \""
943  << stringifyScope(
944  groupOp->getAttrOfType<spirv::ScopeAttr>(kExecutionScopeAttrName)
945  .getValue())
946  << "\" \""
947  << stringifyGroupOperation(groupOp
948  ->getAttrOfType<spirv::GroupOperationAttr>(
950  .getValue())
951  << "\" " << groupOp->getOperand(0);
952 
953  if (groupOp->getNumOperands() > 1)
954  printer << " " << kClusterSize << '(' << groupOp->getOperand(1) << ')';
955  printer << " : " << groupOp->getResult(0).getType();
956 }
957 
959  spirv::Scope scope =
960  groupOp->getAttrOfType<spirv::ScopeAttr>(kExecutionScopeAttrName)
961  .getValue();
962  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
963  return groupOp->emitOpError(
964  "execution scope must be 'Workgroup' or 'Subgroup'");
965 
966  spirv::GroupOperation operation =
967  groupOp->getAttrOfType<spirv::GroupOperationAttr>(kGroupOperationAttrName)
968  .getValue();
969  if (operation == spirv::GroupOperation::ClusteredReduce &&
970  groupOp->getNumOperands() == 1)
971  return groupOp->emitOpError("cluster size operand must be provided for "
972  "'ClusteredReduce' group operation");
973  if (groupOp->getNumOperands() > 1) {
974  Operation *sizeOp = groupOp->getOperand(1).getDefiningOp();
975  int32_t clusterSize = 0;
976 
977  // TODO: support specialization constant here.
978  if (failed(extractValueFromConstOp(sizeOp, clusterSize)))
979  return groupOp->emitOpError(
980  "cluster size operand must come from a constant op");
981 
982  if (!llvm::isPowerOf2_32(clusterSize))
983  return groupOp->emitOpError(
984  "cluster size operand must be a power of two");
985  }
986  return success();
987 }
988 
989 /// Result of a logical op must be a scalar or vector of boolean type.
990 static Type getUnaryOpResultType(Type operandType) {
991  Builder builder(operandType.getContext());
992  Type resultType = builder.getIntegerType(1);
993  if (auto vecType = operandType.dyn_cast<VectorType>())
994  return VectorType::get(vecType.getNumElements(), resultType);
995  return resultType;
996 }
997 
999  if (op->getOperand(0).getType() != op->getResult(0).getType()) {
1000  return op->emitError("expected the same type for the first operand and "
1001  "result, but provided ")
1002  << op->getOperand(0).getType() << " and "
1003  << op->getResult(0).getType();
1004  }
1005  return success();
1006 }
1007 
1008 //===----------------------------------------------------------------------===//
1009 // spirv.AccessChainOp
1010 //===----------------------------------------------------------------------===//
1011 
1012 static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
1013  auto ptrType = type.dyn_cast<spirv::PointerType>();
1014  if (!ptrType) {
1015  emitError(baseLoc, "'spirv.AccessChain' op expected a pointer "
1016  "to composite type, but provided ")
1017  << type;
1018  return nullptr;
1019  }
1020 
1021  auto resultType = ptrType.getPointeeType();
1022  auto resultStorageClass = ptrType.getStorageClass();
1023  int32_t index = 0;
1024 
1025  for (auto indexSSA : indices) {
1026  auto cType = resultType.dyn_cast<spirv::CompositeType>();
1027  if (!cType) {
1028  emitError(
1029  baseLoc,
1030  "'spirv.AccessChain' op cannot extract from non-composite type ")
1031  << resultType << " with index " << index;
1032  return nullptr;
1033  }
1034  index = 0;
1035  if (resultType.isa<spirv::StructType>()) {
1036  Operation *op = indexSSA.getDefiningOp();
1037  if (!op) {
1038  emitError(baseLoc, "'spirv.AccessChain' op index must be an "
1039  "integer spirv.Constant to access "
1040  "element of spirv.struct");
1041  return nullptr;
1042  }
1043 
1044  // TODO: this should be relaxed to allow
1045  // integer literals of other bitwidths.
1046  if (failed(extractValueFromConstOp(op, index))) {
1047  emitError(
1048  baseLoc,
1049  "'spirv.AccessChain' index must be an integer spirv.Constant to "
1050  "access element of spirv.struct, but provided ")
1051  << op->getName();
1052  return nullptr;
1053  }
1054  if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) {
1055  emitError(baseLoc, "'spirv.AccessChain' op index ")
1056  << index << " out of bounds for " << resultType;
1057  return nullptr;
1058  }
1059  }
1060  resultType = cType.getElementType(index);
1061  }
1062  return spirv::PointerType::get(resultType, resultStorageClass);
1063 }
1064 
1065 void spirv::AccessChainOp::build(OpBuilder &builder, OperationState &state,
1066  Value basePtr, ValueRange indices) {
1067  auto type = getElementPtrType(basePtr.getType(), indices, state.location);
1068  assert(type && "Unable to deduce return type based on basePtr and indices");
1069  build(builder, state, type, basePtr, indices);
1070 }
1071 
1072 ParseResult spirv::AccessChainOp::parse(OpAsmParser &parser,
1073  OperationState &result) {
1076  Type type;
1077  auto loc = parser.getCurrentLocation();
1078  SmallVector<Type, 4> indicesTypes;
1079 
1080  if (parser.parseOperand(ptrInfo) ||
1081  parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
1082  parser.parseColonType(type) ||
1083  parser.resolveOperand(ptrInfo, type, result.operands)) {
1084  return failure();
1085  }
1086 
1087  // Check that the provided indices list is not empty before parsing their
1088  // type list.
1089  if (indicesInfo.empty()) {
1090  return mlir::emitError(result.location,
1091  "'spirv.AccessChain' op expected at "
1092  "least one index ");
1093  }
1094 
1095  if (parser.parseComma() || parser.parseTypeList(indicesTypes))
1096  return failure();
1097 
1098  // Check that the indices types list is not empty and that it has a one-to-one
1099  // mapping to the provided indices.
1100  if (indicesTypes.size() != indicesInfo.size()) {
1101  return mlir::emitError(
1102  result.location, "'spirv.AccessChain' op indices types' count must be "
1103  "equal to indices info count");
1104  }
1105 
1106  if (parser.resolveOperands(indicesInfo, indicesTypes, loc, result.operands))
1107  return failure();
1108 
1109  auto resultType = getElementPtrType(
1110  type, llvm::ArrayRef(result.operands).drop_front(), result.location);
1111  if (!resultType) {
1112  return failure();
1113  }
1114 
1115  result.addTypes(resultType);
1116  return success();
1117 }
1118 
1119 template <typename Op>
1120 static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer) {
1121  printer << ' ' << op.getBasePtr() << '[' << indices
1122  << "] : " << op.getBasePtr().getType() << ", " << indices.getTypes();
1123 }
1124 
1126  printAccessChain(*this, getIndices(), printer);
1127 }
1128 
1129 template <typename Op>
1130 static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) {
1131  auto resultType = getElementPtrType(accessChainOp.getBasePtr().getType(),
1132  indices, accessChainOp.getLoc());
1133  if (!resultType)
1134  return failure();
1135 
1136  auto providedResultType =
1137  accessChainOp.getType().template dyn_cast<spirv::PointerType>();
1138  if (!providedResultType)
1139  return accessChainOp.emitOpError(
1140  "result type must be a pointer, but provided")
1141  << providedResultType;
1142 
1143  if (resultType != providedResultType)
1144  return accessChainOp.emitOpError("invalid result type: expected ")
1145  << resultType << ", but provided " << providedResultType;
1146 
1147  return success();
1148 }
1149 
1151  return verifyAccessChain(*this, getIndices());
1152 }
1153 
1154 //===----------------------------------------------------------------------===//
1155 // spirv.mlir.addressof
1156 //===----------------------------------------------------------------------===//
1157 
1158 void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state,
1159  spirv::GlobalVariableOp var) {
1160  build(builder, state, var.getType(), SymbolRefAttr::get(var));
1161 }
1162 
1164  auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
1165  SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(),
1166  getVariableAttr()));
1167  if (!varOp) {
1168  return emitOpError("expected spirv.GlobalVariable symbol");
1169  }
1170  if (getPointer().getType() != varOp.getType()) {
1171  return emitOpError(
1172  "result type mismatch with the referenced global variable's type");
1173  }
1174  return success();
1175 }
1176 
1177 template <typename T>
1178 static void printAtomicCompareExchangeImpl(T atomOp, OpAsmPrinter &printer) {
1179  printer << " \"" << stringifyScope(atomOp.getMemoryScope()) << "\" \""
1180  << stringifyMemorySemantics(atomOp.getEqualSemantics()) << "\" \""
1181  << stringifyMemorySemantics(atomOp.getUnequalSemantics()) << "\" "
1182  << atomOp.getOperands() << " : " << atomOp.getPointer().getType();
1183 }
1184 
1186  OperationState &state) {
1187  spirv::Scope memoryScope;
1188  spirv::MemorySemantics equalSemantics, unequalSemantics;
1190  Type type;
1191  if (parseEnumStrAttr<spirv::ScopeAttr>(memoryScope, parser, state,
1193  parseEnumStrAttr<spirv::MemorySemanticsAttr>(
1194  equalSemantics, parser, state, kEqualSemanticsAttrName) ||
1195  parseEnumStrAttr<spirv::MemorySemanticsAttr>(
1196  unequalSemantics, parser, state, kUnequalSemanticsAttrName) ||
1197  parser.parseOperandList(operandInfo, 3))
1198  return failure();
1199 
1200  auto loc = parser.getCurrentLocation();
1201  if (parser.parseColonType(type))
1202  return failure();
1203 
1204  auto ptrType = type.dyn_cast<spirv::PointerType>();
1205  if (!ptrType)
1206  return parser.emitError(loc, "expected pointer type");
1207 
1208  if (parser.resolveOperands(
1209  operandInfo,
1210  {ptrType, ptrType.getPointeeType(), ptrType.getPointeeType()},
1211  parser.getNameLoc(), state.operands))
1212  return failure();
1213 
1214  return parser.addTypeToList(ptrType.getPointeeType(), state.types);
1215 }
1216 
1217 template <typename T>
1219  // According to the spec:
1220  // "The type of Value must be the same as Result Type. The type of the value
1221  // pointed to by Pointer must be the same as Result Type. This type must also
1222  // match the type of Comparator."
1223  if (atomOp.getType() != atomOp.getValue().getType())
1224  return atomOp.emitOpError("value operand must have the same type as the op "
1225  "result, but found ")
1226  << atomOp.getValue().getType() << " vs " << atomOp.getType();
1227 
1228  if (atomOp.getType() != atomOp.getComparator().getType())
1229  return atomOp.emitOpError(
1230  "comparator operand must have the same type as the op "
1231  "result, but found ")
1232  << atomOp.getComparator().getType() << " vs " << atomOp.getType();
1233 
1234  Type pointeeType = atomOp.getPointer()
1235  .getType()
1236  .template cast<spirv::PointerType>()
1237  .getPointeeType();
1238  if (atomOp.getType() != pointeeType)
1239  return atomOp.emitOpError(
1240  "pointer operand's pointee type must have the same "
1241  "as the op result type, but found ")
1242  << pointeeType << " vs " << atomOp.getType();
1243 
1244  // TODO: Unequal cannot be set to Release or Acquire and Release.
1245  // In addition, Unequal cannot be set to a stronger memory-order than Equal.
1246 
1247  return success();
1248 }
1249 
1250 //===----------------------------------------------------------------------===//
1251 // spirv.AtomicAndOp
1252 //===----------------------------------------------------------------------===//
1253 
1255  return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1256 }
1257 
1258 ParseResult spirv::AtomicAndOp::parse(OpAsmParser &parser,
1259  OperationState &result) {
1260  return ::parseAtomicUpdateOp(parser, result, true);
1261 }
1263  ::printAtomicUpdateOp(*this, p);
1264 }
1265 
1266 //===----------------------------------------------------------------------===//
1267 // spirv.AtomicCompareExchangeOp
1268 //===----------------------------------------------------------------------===//
1269 
1272 }
1273 
1274 ParseResult spirv::AtomicCompareExchangeOp::parse(OpAsmParser &parser,
1275  OperationState &result) {
1277 }
1280 }
1281 
1282 //===----------------------------------------------------------------------===//
1283 // spirv.AtomicCompareExchangeWeakOp
1284 //===----------------------------------------------------------------------===//
1285 
1288 }
1289 
1290 ParseResult spirv::AtomicCompareExchangeWeakOp::parse(OpAsmParser &parser,
1291  OperationState &result) {
1293 }
1296 }
1297 
1298 //===----------------------------------------------------------------------===//
1299 // spirv.AtomicExchange
1300 //===----------------------------------------------------------------------===//
1301 
1303  printer << " \"" << stringifyScope(getMemoryScope()) << "\" \""
1304  << stringifyMemorySemantics(getSemantics()) << "\" " << getOperands()
1305  << " : " << getPointer().getType();
1306 }
1307 
1308 ParseResult spirv::AtomicExchangeOp::parse(OpAsmParser &parser,
1309  OperationState &result) {
1310  spirv::Scope memoryScope;
1311  spirv::MemorySemantics semantics;
1313  Type type;
1314  if (parseEnumStrAttr<spirv::ScopeAttr>(memoryScope, parser, result,
1316  parseEnumStrAttr<spirv::MemorySemanticsAttr>(semantics, parser, result,
1317  kSemanticsAttrName) ||
1318  parser.parseOperandList(operandInfo, 2))
1319  return failure();
1320 
1321  auto loc = parser.getCurrentLocation();
1322  if (parser.parseColonType(type))
1323  return failure();
1324 
1325  auto ptrType = type.dyn_cast<spirv::PointerType>();
1326  if (!ptrType)
1327  return parser.emitError(loc, "expected pointer type");
1328 
1329  if (parser.resolveOperands(operandInfo, {ptrType, ptrType.getPointeeType()},
1330  parser.getNameLoc(), result.operands))
1331  return failure();
1332 
1333  return parser.addTypeToList(ptrType.getPointeeType(), result.types);
1334 }
1335 
1337  if (getType() != getValue().getType())
1338  return emitOpError("value operand must have the same type as the op "
1339  "result, but found ")
1340  << getValue().getType() << " vs " << getType();
1341 
1342  Type pointeeType =
1343  getPointer().getType().cast<spirv::PointerType>().getPointeeType();
1344  if (getType() != pointeeType)
1345  return emitOpError("pointer operand's pointee type must have the same "
1346  "as the op result type, but found ")
1347  << pointeeType << " vs " << getType();
1348 
1349  return success();
1350 }
1351 
1352 //===----------------------------------------------------------------------===//
1353 // spirv.AtomicIAddOp
1354 //===----------------------------------------------------------------------===//
1355 
1357  return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1358 }
1359 
1360 ParseResult spirv::AtomicIAddOp::parse(OpAsmParser &parser,
1361  OperationState &result) {
1362  return ::parseAtomicUpdateOp(parser, result, true);
1363 }
1365  ::printAtomicUpdateOp(*this, p);
1366 }
1367 
1368 //===----------------------------------------------------------------------===//
1369 // spirv.EXT.AtomicFAddOp
1370 //===----------------------------------------------------------------------===//
1371 
1373  return ::verifyAtomicUpdateOp<FloatType>(getOperation());
1374 }
1375 
1376 ParseResult spirv::EXTAtomicFAddOp::parse(OpAsmParser &parser,
1377  OperationState &result) {
1378  return ::parseAtomicUpdateOp(parser, result, true);
1379 }
1381  ::printAtomicUpdateOp(*this, p);
1382 }
1383 
1384 //===----------------------------------------------------------------------===//
1385 // spirv.AtomicIDecrementOp
1386 //===----------------------------------------------------------------------===//
1387 
1389  return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1390 }
1391 
1392 ParseResult spirv::AtomicIDecrementOp::parse(OpAsmParser &parser,
1393  OperationState &result) {
1394  return ::parseAtomicUpdateOp(parser, result, false);
1395 }
1397  ::printAtomicUpdateOp(*this, p);
1398 }
1399 
1400 //===----------------------------------------------------------------------===//
1401 // spirv.AtomicIIncrementOp
1402 //===----------------------------------------------------------------------===//
1403 
1405  return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1406 }
1407 
1408 ParseResult spirv::AtomicIIncrementOp::parse(OpAsmParser &parser,
1409  OperationState &result) {
1410  return ::parseAtomicUpdateOp(parser, result, false);
1411 }
1413  ::printAtomicUpdateOp(*this, p);
1414 }
1415 
1416 //===----------------------------------------------------------------------===//
1417 // spirv.AtomicISubOp
1418 //===----------------------------------------------------------------------===//
1419 
1421  return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1422 }
1423 
1424 ParseResult spirv::AtomicISubOp::parse(OpAsmParser &parser,
1425  OperationState &result) {
1426  return ::parseAtomicUpdateOp(parser, result, true);
1427 }
1429  ::printAtomicUpdateOp(*this, p);
1430 }
1431 
1432 //===----------------------------------------------------------------------===//
1433 // spirv.AtomicOrOp
1434 //===----------------------------------------------------------------------===//
1435 
1437  return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1438 }
1439 
1440 ParseResult spirv::AtomicOrOp::parse(OpAsmParser &parser,
1441  OperationState &result) {
1442  return ::parseAtomicUpdateOp(parser, result, true);
1443 }
1445  ::printAtomicUpdateOp(*this, p);
1446 }
1447 
1448 //===----------------------------------------------------------------------===//
1449 // spirv.AtomicSMaxOp
1450 //===----------------------------------------------------------------------===//
1451 
1453  return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1454 }
1455 
1456 ParseResult spirv::AtomicSMaxOp::parse(OpAsmParser &parser,
1457  OperationState &result) {
1458  return ::parseAtomicUpdateOp(parser, result, true);
1459 }
1461  ::printAtomicUpdateOp(*this, p);
1462 }
1463 
1464 //===----------------------------------------------------------------------===//
1465 // spirv.AtomicSMinOp
1466 //===----------------------------------------------------------------------===//
1467 
1469  return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1470 }
1471 
1472 ParseResult spirv::AtomicSMinOp::parse(OpAsmParser &parser,
1473  OperationState &result) {
1474  return ::parseAtomicUpdateOp(parser, result, true);
1475 }
1477  ::printAtomicUpdateOp(*this, p);
1478 }
1479 
1480 //===----------------------------------------------------------------------===//
1481 // spirv.AtomicUMaxOp
1482 //===----------------------------------------------------------------------===//
1483 
1485  return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1486 }
1487 
1488 ParseResult spirv::AtomicUMaxOp::parse(OpAsmParser &parser,
1489  OperationState &result) {
1490  return ::parseAtomicUpdateOp(parser, result, true);
1491 }
1493  ::printAtomicUpdateOp(*this, p);
1494 }
1495 
1496 //===----------------------------------------------------------------------===//
1497 // spirv.AtomicUMinOp
1498 //===----------------------------------------------------------------------===//
1499 
1501  return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1502 }
1503 
1504 ParseResult spirv::AtomicUMinOp::parse(OpAsmParser &parser,
1505  OperationState &result) {
1506  return ::parseAtomicUpdateOp(parser, result, true);
1507 }
1509  ::printAtomicUpdateOp(*this, p);
1510 }
1511 
1512 //===----------------------------------------------------------------------===//
1513 // spirv.AtomicXorOp
1514 //===----------------------------------------------------------------------===//
1515 
1517  return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1518 }
1519 
1520 ParseResult spirv::AtomicXorOp::parse(OpAsmParser &parser,
1521  OperationState &result) {
1522  return ::parseAtomicUpdateOp(parser, result, true);
1523 }
1525  ::printAtomicUpdateOp(*this, p);
1526 }
1527 
1528 //===----------------------------------------------------------------------===//
1529 // spirv.BitcastOp
1530 //===----------------------------------------------------------------------===//
1531 
1533  // TODO: The SPIR-V spec validation rules are different for different
1534  // versions.
1535  auto operandType = getOperand().getType();
1536  auto resultType = getResult().getType();
1537  if (operandType == resultType) {
1538  return emitError("result type must be different from operand type");
1539  }
1540  if (operandType.isa<spirv::PointerType>() &&
1541  !resultType.isa<spirv::PointerType>()) {
1542  return emitError(
1543  "unhandled bit cast conversion from pointer type to non-pointer type");
1544  }
1545  if (!operandType.isa<spirv::PointerType>() &&
1546  resultType.isa<spirv::PointerType>()) {
1547  return emitError(
1548  "unhandled bit cast conversion from non-pointer type to pointer type");
1549  }
1550  auto operandBitWidth = getBitWidth(operandType);
1551  auto resultBitWidth = getBitWidth(resultType);
1552  if (operandBitWidth != resultBitWidth) {
1553  return emitOpError("mismatch in result type bitwidth ")
1554  << resultBitWidth << " and operand type bitwidth "
1555  << operandBitWidth;
1556  }
1557  return success();
1558 }
1559 
1560 //===----------------------------------------------------------------------===//
1561 // spirv.PtrCastToGenericOp
1562 //===----------------------------------------------------------------------===//
1563 
1565  auto operandType = getPointer().getType().cast<spirv::PointerType>();
1566  auto resultType = getResult().getType().cast<spirv::PointerType>();
1567 
1568  spirv::StorageClass operandStorage = operandType.getStorageClass();
1569  if (operandStorage != spirv::StorageClass::Workgroup &&
1570  operandStorage != spirv::StorageClass::CrossWorkgroup &&
1571  operandStorage != spirv::StorageClass::Function)
1572  return emitError("pointer must point to the Workgroup, CrossWorkgroup"
1573  ", or Function Storage Class");
1574 
1575  spirv::StorageClass resultStorage = resultType.getStorageClass();
1576  if (resultStorage != spirv::StorageClass::Generic)
1577  return emitError("result type must be of storage class Generic");
1578 
1579  Type operandPointeeType = operandType.getPointeeType();
1580  Type resultPointeeType = resultType.getPointeeType();
1581  if (operandPointeeType != resultPointeeType)
1582  return emitOpError("pointer operand's pointee type must have the same "
1583  "as the op result type, but found ")
1584  << operandPointeeType << " vs " << resultPointeeType;
1585  return success();
1586 }
1587 
1588 //===----------------------------------------------------------------------===//
1589 // spirv.GenericCastToPtrOp
1590 //===----------------------------------------------------------------------===//
1591 
1593  auto operandType = getPointer().getType().cast<spirv::PointerType>();
1594  auto resultType = getResult().getType().cast<spirv::PointerType>();
1595 
1596  spirv::StorageClass operandStorage = operandType.getStorageClass();
1597  if (operandStorage != spirv::StorageClass::Generic)
1598  return emitError("pointer type must be of storage class Generic");
1599 
1600  spirv::StorageClass resultStorage = resultType.getStorageClass();
1601  if (resultStorage != spirv::StorageClass::Workgroup &&
1602  resultStorage != spirv::StorageClass::CrossWorkgroup &&
1603  resultStorage != spirv::StorageClass::Function)
1604  return emitError("result must point to the Workgroup, CrossWorkgroup, "
1605  "or Function Storage Class");
1606 
1607  Type operandPointeeType = operandType.getPointeeType();
1608  Type resultPointeeType = resultType.getPointeeType();
1609  if (operandPointeeType != resultPointeeType)
1610  return emitOpError("pointer operand's pointee type must have the same "
1611  "as the op result type, but found ")
1612  << operandPointeeType << " vs " << resultPointeeType;
1613  return success();
1614 }
1615 
1616 //===----------------------------------------------------------------------===//
1617 // spirv.GenericCastToPtrExplicitOp
1618 //===----------------------------------------------------------------------===//
1619 
1621  auto operandType = getPointer().getType().cast<spirv::PointerType>();
1622  auto resultType = getResult().getType().cast<spirv::PointerType>();
1623 
1624  spirv::StorageClass operandStorage = operandType.getStorageClass();
1625  if (operandStorage != spirv::StorageClass::Generic)
1626  return emitError("pointer type must be of storage class Generic");
1627 
1628  spirv::StorageClass resultStorage = resultType.getStorageClass();
1629  if (resultStorage != spirv::StorageClass::Workgroup &&
1630  resultStorage != spirv::StorageClass::CrossWorkgroup &&
1631  resultStorage != spirv::StorageClass::Function)
1632  return emitError("result must point to the Workgroup, CrossWorkgroup, "
1633  "or Function Storage Class");
1634 
1635  Type operandPointeeType = operandType.getPointeeType();
1636  Type resultPointeeType = resultType.getPointeeType();
1637  if (operandPointeeType != resultPointeeType)
1638  return emitOpError("pointer operand's pointee type must have the same "
1639  "as the op result type, but found ")
1640  << operandPointeeType << " vs " << resultPointeeType;
1641  return success();
1642 }
1643 
1644 //===----------------------------------------------------------------------===//
1645 // spirv.BranchOp
1646 //===----------------------------------------------------------------------===//
1647 
1648 SuccessorOperands spirv::BranchOp::getSuccessorOperands(unsigned index) {
1649  assert(index == 0 && "invalid successor index");
1650  return SuccessorOperands(0, getTargetOperandsMutable());
1651 }
1652 
1653 //===----------------------------------------------------------------------===//
1654 // spirv.BranchConditionalOp
1655 //===----------------------------------------------------------------------===//
1656 
1658 spirv::BranchConditionalOp::getSuccessorOperands(unsigned index) {
1659  assert(index < 2 && "invalid successor index");
1660  return SuccessorOperands(index == kTrueIndex
1661  ? getTrueTargetOperandsMutable()
1662  : getFalseTargetOperandsMutable());
1663 }
1664 
1665 ParseResult spirv::BranchConditionalOp::parse(OpAsmParser &parser,
1666  OperationState &result) {
1667  auto &builder = parser.getBuilder();
1669  Block *dest;
1670 
1671  // Parse the condition.
1672  Type boolTy = builder.getI1Type();
1673  if (parser.parseOperand(condInfo) ||
1674  parser.resolveOperand(condInfo, boolTy, result.operands))
1675  return failure();
1676 
1677  // Parse the optional branch weights.
1678  if (succeeded(parser.parseOptionalLSquare())) {
1679  IntegerAttr trueWeight, falseWeight;
1680  NamedAttrList weights;
1681 
1682  auto i32Type = builder.getIntegerType(32);
1683  if (parser.parseAttribute(trueWeight, i32Type, "weight", weights) ||
1684  parser.parseComma() ||
1685  parser.parseAttribute(falseWeight, i32Type, "weight", weights) ||
1686  parser.parseRSquare())
1687  return failure();
1688 
1690  builder.getArrayAttr({trueWeight, falseWeight}));
1691  }
1692 
1693  // Parse the true branch.
1694  SmallVector<Value, 4> trueOperands;
1695  if (parser.parseComma() ||
1696  parser.parseSuccessorAndUseList(dest, trueOperands))
1697  return failure();
1698  result.addSuccessors(dest);
1699  result.addOperands(trueOperands);
1700 
1701  // Parse the false branch.
1702  SmallVector<Value, 4> falseOperands;
1703  if (parser.parseComma() ||
1704  parser.parseSuccessorAndUseList(dest, falseOperands))
1705  return failure();
1706  result.addSuccessors(dest);
1707  result.addOperands(falseOperands);
1708  result.addAttribute(spirv::BranchConditionalOp::getOperandSegmentSizeAttr(),
1709  builder.getDenseI32ArrayAttr(
1710  {1, static_cast<int32_t>(trueOperands.size()),
1711  static_cast<int32_t>(falseOperands.size())}));
1712 
1713  return success();
1714 }
1715 
1717  printer << ' ' << getCondition();
1718 
1719  if (auto weights = getBranchWeights()) {
1720  printer << " [";
1721  llvm::interleaveComma(weights->getValue(), printer, [&](Attribute a) {
1722  printer << a.cast<IntegerAttr>().getInt();
1723  });
1724  printer << "]";
1725  }
1726 
1727  printer << ", ";
1728  printer.printSuccessorAndUseList(getTrueBlock(), getTrueBlockArguments());
1729  printer << ", ";
1730  printer.printSuccessorAndUseList(getFalseBlock(), getFalseBlockArguments());
1731 }
1732 
1734  if (auto weights = getBranchWeights()) {
1735  if (weights->getValue().size() != 2) {
1736  return emitOpError("must have exactly two branch weights");
1737  }
1738  if (llvm::all_of(*weights, [](Attribute attr) {
1739  return attr.cast<IntegerAttr>().getValue().isZero();
1740  }))
1741  return emitOpError("branch weights cannot both be zero");
1742  }
1743 
1744  return success();
1745 }
1746 
1747 //===----------------------------------------------------------------------===//
1748 // spirv.CompositeConstruct
1749 //===----------------------------------------------------------------------===//
1750 
1752  auto cType = getType().cast<spirv::CompositeType>();
1753  operand_range constituents = this->getConstituents();
1754 
1755  if (auto coopType = cType.dyn_cast<spirv::CooperativeMatrixNVType>()) {
1756  if (constituents.size() != 1)
1757  return emitOpError("has incorrect number of operands: expected ")
1758  << "1, but provided " << constituents.size();
1759  if (coopType.getElementType() != constituents.front().getType())
1760  return emitOpError("operand type mismatch: expected operand type ")
1761  << coopType.getElementType() << ", but provided "
1762  << constituents.front().getType();
1763  return success();
1764  }
1765 
1766  if (auto jointType = cType.dyn_cast<spirv::JointMatrixINTELType>()) {
1767  if (constituents.size() != 1)
1768  return emitOpError("has incorrect number of operands: expected ")
1769  << "1, but provided " << constituents.size();
1770  if (jointType.getElementType() != constituents.front().getType())
1771  return emitOpError("operand type mismatch: expected operand type ")
1772  << jointType.getElementType() << ", but provided "
1773  << constituents.front().getType();
1774  return success();
1775  }
1776 
1777  if (constituents.size() == cType.getNumElements()) {
1778  for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
1779  if (constituents[index].getType() != cType.getElementType(index)) {
1780  return emitOpError("operand type mismatch: expected operand type ")
1781  << cType.getElementType(index) << ", but provided "
1782  << constituents[index].getType();
1783  }
1784  }
1785  return success();
1786  }
1787 
1788  // If not constructing a cooperative matrix type, then we must be constructing
1789  // a vector type.
1790  auto resultType = cType.dyn_cast<VectorType>();
1791  if (!resultType)
1792  return emitOpError(
1793  "expected to return a vector or cooperative matrix when the number of "
1794  "constituents is less than what the result needs");
1795 
1796  SmallVector<unsigned> sizes;
1797  for (Value component : constituents) {
1798  if (!component.getType().isa<VectorType>() &&
1799  !component.getType().isIntOrFloat())
1800  return emitOpError("operand type mismatch: expected operand to have "
1801  "a scalar or vector type, but provided ")
1802  << component.getType();
1803 
1804  Type elementType = component.getType();
1805  if (auto vectorType = component.getType().dyn_cast<VectorType>()) {
1806  sizes.push_back(vectorType.getNumElements());
1807  elementType = vectorType.getElementType();
1808  } else {
1809  sizes.push_back(1);
1810  }
1811 
1812  if (elementType != resultType.getElementType())
1813  return emitOpError("operand element type mismatch: expected to be ")
1814  << resultType.getElementType() << ", but provided " << elementType;
1815  }
1816  unsigned totalCount = std::accumulate(sizes.begin(), sizes.end(), 0);
1817  if (totalCount != cType.getNumElements())
1818  return emitOpError("has incorrect number of operands: expected ")
1819  << cType.getNumElements() << ", but provided " << totalCount;
1820  return success();
1821 }
1822 
1823 //===----------------------------------------------------------------------===//
1824 // spirv.CompositeExtractOp
1825 //===----------------------------------------------------------------------===//
1826 
1827 void spirv::CompositeExtractOp::build(OpBuilder &builder, OperationState &state,
1828  Value composite,
1829  ArrayRef<int32_t> indices) {
1830  auto indexAttr = builder.getI32ArrayAttr(indices);
1831  auto elementType =
1832  getElementType(composite.getType(), indexAttr, state.location);
1833  if (!elementType) {
1834  return;
1835  }
1836  build(builder, state, elementType, composite, indexAttr);
1837 }
1838 
1839 ParseResult spirv::CompositeExtractOp::parse(OpAsmParser &parser,
1840  OperationState &result) {
1841  OpAsmParser::UnresolvedOperand compositeInfo;
1842  Attribute indicesAttr;
1843  Type compositeType;
1844  SMLoc attrLocation;
1845 
1846  if (parser.parseOperand(compositeInfo) ||
1847  parser.getCurrentLocation(&attrLocation) ||
1848  parser.parseAttribute(indicesAttr, kIndicesAttrName, result.attributes) ||
1849  parser.parseColonType(compositeType) ||
1850  parser.resolveOperand(compositeInfo, compositeType, result.operands)) {
1851  return failure();
1852  }
1853 
1854  Type resultType =
1855  getElementType(compositeType, indicesAttr, parser, attrLocation);
1856  if (!resultType) {
1857  return failure();
1858  }
1859  result.addTypes(resultType);
1860  return success();
1861 }
1862 
1864  printer << ' ' << getComposite() << getIndices() << " : "
1865  << getComposite().getType();
1866 }
1867 
1869  auto indicesArrayAttr = getIndices().dyn_cast<ArrayAttr>();
1870  auto resultType =
1871  getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
1872  if (!resultType)
1873  return failure();
1874 
1875  if (resultType != getType()) {
1876  return emitOpError("invalid result type: expected ")
1877  << resultType << " but provided " << getType();
1878  }
1879 
1880  return success();
1881 }
1882 
1883 //===----------------------------------------------------------------------===//
1884 // spirv.CompositeInsert
1885 //===----------------------------------------------------------------------===//
1886 
1887 void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state,
1888  Value object, Value composite,
1889  ArrayRef<int32_t> indices) {
1890  auto indexAttr = builder.getI32ArrayAttr(indices);
1891  build(builder, state, composite.getType(), object, composite, indexAttr);
1892 }
1893 
1894 ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
1895  OperationState &result) {
1897  Type objectType, compositeType;
1898  Attribute indicesAttr;
1899  auto loc = parser.getCurrentLocation();
1900 
1901  return failure(
1902  parser.parseOperandList(operands, 2) ||
1903  parser.parseAttribute(indicesAttr, kIndicesAttrName, result.attributes) ||
1904  parser.parseColonType(objectType) ||
1905  parser.parseKeywordType("into", compositeType) ||
1906  parser.resolveOperands(operands, {objectType, compositeType}, loc,
1907  result.operands) ||
1908  parser.addTypesToList(compositeType, result.types));
1909 }
1910 
1912  auto indicesArrayAttr = getIndices().dyn_cast<ArrayAttr>();
1913  auto objectType =
1914  getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
1915  if (!objectType)
1916  return failure();
1917 
1918  if (objectType != getObject().getType()) {
1919  return emitOpError("object operand type should be ")
1920  << objectType << ", but found " << getObject().getType();
1921  }
1922 
1923  if (getComposite().getType() != getType()) {
1924  return emitOpError("result type should be the same as "
1925  "the composite type, but found ")
1926  << getComposite().getType() << " vs " << getType();
1927  }
1928 
1929  return success();
1930 }
1931 
1933  printer << " " << getObject() << ", " << getComposite() << getIndices()
1934  << " : " << getObject().getType() << " into "
1935  << getComposite().getType();
1936 }
1937 
1938 //===----------------------------------------------------------------------===//
1939 // spirv.Constant
1940 //===----------------------------------------------------------------------===//
1941 
1942 ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
1943  OperationState &result) {
1944  Attribute value;
1945  if (parser.parseAttribute(value, kValueAttrName, result.attributes))
1946  return failure();
1947 
1948  Type type = NoneType::get(parser.getContext());
1949  if (auto typedAttr = value.dyn_cast<TypedAttr>())
1950  type = typedAttr.getType();
1951  if (type.isa<NoneType, TensorType>()) {
1952  if (parser.parseColonType(type))
1953  return failure();
1954  }
1955 
1956  return parser.addTypeToList(type, result.types);
1957 }
1958 
1959 void spirv::ConstantOp::print(OpAsmPrinter &printer) {
1960  printer << ' ' << getValue();
1961  if (getType().isa<spirv::ArrayType>())
1962  printer << " : " << getType();
1963 }
1964 
1965 static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
1966  Type opType) {
1967  if (value.isa<IntegerAttr, FloatAttr>()) {
1968  auto valueType = value.cast<TypedAttr>().getType();
1969  if (valueType != opType)
1970  return op.emitOpError("result type (")
1971  << opType << ") does not match value type (" << valueType << ")";
1972  return success();
1973  }
1974  if (value.isa<DenseIntOrFPElementsAttr, SparseElementsAttr>()) {
1975  auto valueType = value.cast<TypedAttr>().getType();
1976  if (valueType == opType)
1977  return success();
1978  auto arrayType = opType.dyn_cast<spirv::ArrayType>();
1979  auto shapedType = valueType.dyn_cast<ShapedType>();
1980  if (!arrayType)
1981  return op.emitOpError("result or element type (")
1982  << opType << ") does not match value type (" << valueType
1983  << "), must be the same or spirv.array";
1984 
1985  int numElements = arrayType.getNumElements();
1986  auto opElemType = arrayType.getElementType();
1987  while (auto t = opElemType.dyn_cast<spirv::ArrayType>()) {
1988  numElements *= t.getNumElements();
1989  opElemType = t.getElementType();
1990  }
1991  if (!opElemType.isIntOrFloat())
1992  return op.emitOpError("only support nested array result type");
1993 
1994  auto valueElemType = shapedType.getElementType();
1995  if (valueElemType != opElemType) {
1996  return op.emitOpError("result element type (")
1997  << opElemType << ") does not match value element type ("
1998  << valueElemType << ")";
1999  }
2000 
2001  if (numElements != shapedType.getNumElements()) {
2002  return op.emitOpError("result number of elements (")
2003  << numElements << ") does not match value number of elements ("
2004  << shapedType.getNumElements() << ")";
2005  }
2006  return success();
2007  }
2008  if (auto arrayAttr = value.dyn_cast<ArrayAttr>()) {
2009  auto arrayType = opType.dyn_cast<spirv::ArrayType>();
2010  if (!arrayType)
2011  return op.emitOpError(
2012  "must have spirv.array result type for array value");
2013  Type elemType = arrayType.getElementType();
2014  for (Attribute element : arrayAttr.getValue()) {
2015  // Verify array elements recursively.
2016  if (failed(verifyConstantType(op, element, elemType)))
2017  return failure();
2018  }
2019  return success();
2020  }
2021  return op.emitOpError("cannot have attribute: ") << value;
2022 }
2023 
2025  // ODS already generates checks to make sure the result type is valid. We just
2026  // need to additionally check that the value's attribute type is consistent
2027  // with the result type.
2028  return verifyConstantType(*this, getValueAttr(), getType());
2029 }
2030 
2031 bool spirv::ConstantOp::isBuildableWith(Type type) {
2032  // Must be valid SPIR-V type first.
2033  if (!type.isa<spirv::SPIRVType>())
2034  return false;
2035 
2036  if (isa<SPIRVDialect>(type.getDialect())) {
2037  // TODO: support constant struct
2038  return type.isa<spirv::ArrayType>();
2039  }
2040 
2041  return true;
2042 }
2043 
2044 spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
2045  OpBuilder &builder) {
2046  if (auto intType = type.dyn_cast<IntegerType>()) {
2047  unsigned width = intType.getWidth();
2048  if (width == 1)
2049  return builder.create<spirv::ConstantOp>(loc, type,
2050  builder.getBoolAttr(false));
2051  return builder.create<spirv::ConstantOp>(
2052  loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
2053  }
2054  if (auto floatType = type.dyn_cast<FloatType>()) {
2055  return builder.create<spirv::ConstantOp>(
2056  loc, type, builder.getFloatAttr(floatType, 0.0));
2057  }
2058  if (auto vectorType = type.dyn_cast<VectorType>()) {
2059  Type elemType = vectorType.getElementType();
2060  if (elemType.isa<IntegerType>()) {
2061  return builder.create<spirv::ConstantOp>(
2062  loc, type,
2063  DenseElementsAttr::get(vectorType,
2064  IntegerAttr::get(elemType, 0).getValue()));
2065  }
2066  if (elemType.isa<FloatType>()) {
2067  return builder.create<spirv::ConstantOp>(
2068  loc, type,
2069  DenseFPElementsAttr::get(vectorType,
2070  FloatAttr::get(elemType, 0.0).getValue()));
2071  }
2072  }
2073 
2074  llvm_unreachable("unimplemented types for ConstantOp::getZero()");
2075 }
2076 
2077 spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
2078  OpBuilder &builder) {
2079  if (auto intType = type.dyn_cast<IntegerType>()) {
2080  unsigned width = intType.getWidth();
2081  if (width == 1)
2082  return builder.create<spirv::ConstantOp>(loc, type,
2083  builder.getBoolAttr(true));
2084  return builder.create<spirv::ConstantOp>(
2085  loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
2086  }
2087  if (auto floatType = type.dyn_cast<FloatType>()) {
2088  return builder.create<spirv::ConstantOp>(
2089  loc, type, builder.getFloatAttr(floatType, 1.0));
2090  }
2091  if (auto vectorType = type.dyn_cast<VectorType>()) {
2092  Type elemType = vectorType.getElementType();
2093  if (elemType.isa<IntegerType>()) {
2094  return builder.create<spirv::ConstantOp>(
2095  loc, type,
2096  DenseElementsAttr::get(vectorType,
2097  IntegerAttr::get(elemType, 1).getValue()));
2098  }
2099  if (elemType.isa<FloatType>()) {
2100  return builder.create<spirv::ConstantOp>(
2101  loc, type,
2102  DenseFPElementsAttr::get(vectorType,
2103  FloatAttr::get(elemType, 1.0).getValue()));
2104  }
2105  }
2106 
2107  llvm_unreachable("unimplemented types for ConstantOp::getOne()");
2108 }
2109 
2110 void mlir::spirv::ConstantOp::getAsmResultNames(
2111  llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
2112  Type type = getType();
2113 
2114  SmallString<32> specialNameBuffer;
2115  llvm::raw_svector_ostream specialName(specialNameBuffer);
2116  specialName << "cst";
2117 
2118  IntegerType intTy = type.dyn_cast<IntegerType>();
2119 
2120  if (IntegerAttr intCst = getValue().dyn_cast<IntegerAttr>()) {
2121  if (intTy && intTy.getWidth() == 1) {
2122  return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
2123  }
2124 
2125  if (intTy.isSignless()) {
2126  specialName << intCst.getInt();
2127  } else if (intTy.isUnsigned()) {
2128  specialName << intCst.getUInt();
2129  } else {
2130  specialName << intCst.getSInt();
2131  }
2132  }
2133 
2134  if (intTy || type.isa<FloatType>()) {
2135  specialName << '_' << type;
2136  }
2137 
2138  if (auto vecType = type.dyn_cast<VectorType>()) {
2139  specialName << "_vec_";
2140  specialName << vecType.getDimSize(0);
2141 
2142  Type elementType = vecType.getElementType();
2143 
2144  if (elementType.isa<IntegerType>() || elementType.isa<FloatType>()) {
2145  specialName << "x" << elementType;
2146  }
2147  }
2148 
2149  setNameFn(getResult(), specialName.str());
2150 }
2151 
2152 void mlir::spirv::AddressOfOp::getAsmResultNames(
2153  llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
2154  SmallString<32> specialNameBuffer;
2155  llvm::raw_svector_ostream specialName(specialNameBuffer);
2156  specialName << getVariable() << "_addr";
2157  setNameFn(getResult(), specialName.str());
2158 }
2159 
2160 //===----------------------------------------------------------------------===//
2161 // spirv.ControlBarrierOp
2162 //===----------------------------------------------------------------------===//
2163 
2165  return verifyMemorySemantics(getOperation(), getMemorySemantics());
2166 }
2167 
2168 //===----------------------------------------------------------------------===//
2169 // spirv.ConvertFToSOp
2170 //===----------------------------------------------------------------------===//
2171 
2173  return verifyCastOp(*this, /*requireSameBitWidth=*/false,
2174  /*skipBitWidthCheck=*/true);
2175 }
2176 
2177 //===----------------------------------------------------------------------===//
2178 // spirv.ConvertFToUOp
2179 //===----------------------------------------------------------------------===//
2180 
2182  return verifyCastOp(*this, /*requireSameBitWidth=*/false,
2183  /*skipBitWidthCheck=*/true);
2184 }
2185 
2186 //===----------------------------------------------------------------------===//
2187 // spirv.ConvertSToFOp
2188 //===----------------------------------------------------------------------===//
2189 
2191  return verifyCastOp(*this, /*requireSameBitWidth=*/false,
2192  /*skipBitWidthCheck=*/true);
2193 }
2194 
2195 //===----------------------------------------------------------------------===//
2196 // spirv.ConvertUToFOp
2197 //===----------------------------------------------------------------------===//
2198 
2200  return verifyCastOp(*this, /*requireSameBitWidth=*/false,
2201  /*skipBitWidthCheck=*/true);
2202 }
2203 
2204 //===----------------------------------------------------------------------===//
2205 // spirv.EntryPoint
2206 //===----------------------------------------------------------------------===//
2207 
2208 void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state,
2209  spirv::ExecutionModel executionModel,
2210  spirv::FuncOp function,
2211  ArrayRef<Attribute> interfaceVars) {
2212  build(builder, state,
2213  spirv::ExecutionModelAttr::get(builder.getContext(), executionModel),
2214  SymbolRefAttr::get(function), builder.getArrayAttr(interfaceVars));
2215 }
2216 
2217 ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
2218  OperationState &result) {
2219  spirv::ExecutionModel execModel;
2221  SmallVector<Type, 0> idTypes;
2222  SmallVector<Attribute, 4> interfaceVars;
2223 
2224  FlatSymbolRefAttr fn;
2225  if (parseEnumStrAttr<spirv::ExecutionModelAttr>(execModel, parser, result) ||
2226  parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) {
2227  return failure();
2228  }
2229 
2230  if (!parser.parseOptionalComma()) {
2231  // Parse the interface variables
2232  if (parser.parseCommaSeparatedList([&]() -> ParseResult {
2233  // The name of the interface variable attribute isnt important
2234  FlatSymbolRefAttr var;
2235  NamedAttrList attrs;
2236  if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
2237  return failure();
2238  interfaceVars.push_back(var);
2239  return success();
2240  }))
2241  return failure();
2242  }
2244  parser.getBuilder().getArrayAttr(interfaceVars));
2245  return success();
2246 }
2247 
2249  printer << " \"" << stringifyExecutionModel(getExecutionModel()) << "\" ";
2250  printer.printSymbolName(getFn());
2251  auto interfaceVars = getInterface().getValue();
2252  if (!interfaceVars.empty()) {
2253  printer << ", ";
2254  llvm::interleaveComma(interfaceVars, printer);
2255  }
2256 }
2257 
2259  // Checks for fn and interface symbol reference are done in spirv::ModuleOp
2260  // verification.
2261  return success();
2262 }
2263 
2264 //===----------------------------------------------------------------------===//
2265 // spirv.ExecutionMode
2266 //===----------------------------------------------------------------------===//
2267 
2268 void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state,
2269  spirv::FuncOp function,
2270  spirv::ExecutionMode executionMode,
2271  ArrayRef<int32_t> params) {
2272  build(builder, state, SymbolRefAttr::get(function),
2273  spirv::ExecutionModeAttr::get(builder.getContext(), executionMode),
2274  builder.getI32ArrayAttr(params));
2275 }
2276 
2277 ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
2278  OperationState &result) {
2279  spirv::ExecutionMode execMode;
2280  Attribute fn;
2281  if (parser.parseAttribute(fn, kFnNameAttrName, result.attributes) ||
2282  parseEnumStrAttr<spirv::ExecutionModeAttr>(execMode, parser, result)) {
2283  return failure();
2284  }
2285 
2286  SmallVector<int32_t, 4> values;
2287  Type i32Type = parser.getBuilder().getIntegerType(32);
2288  while (!parser.parseOptionalComma()) {
2289  NamedAttrList attr;
2290  Attribute value;
2291  if (parser.parseAttribute(value, i32Type, "value", attr)) {
2292  return failure();
2293  }
2294  values.push_back(value.cast<IntegerAttr>().getInt());
2295  }
2297  parser.getBuilder().getI32ArrayAttr(values));
2298  return success();
2299 }
2300 
2302  printer << " ";
2303  printer.printSymbolName(getFn());
2304  printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\"";
2305  auto values = this->getValues();
2306  if (values.empty())
2307  return;
2308  printer << ", ";
2309  llvm::interleaveComma(values, printer, [&](Attribute a) {
2310  printer << a.cast<IntegerAttr>().getInt();
2311  });
2312 }
2313 
2314 //===----------------------------------------------------------------------===//
2315 // spirv.FConvertOp
2316 //===----------------------------------------------------------------------===//
2317 
2319  return verifyCastOp(*this, /*requireSameBitWidth=*/false);
2320 }
2321 
2322 //===----------------------------------------------------------------------===//
2323 // spirv.SConvertOp
2324 //===----------------------------------------------------------------------===//
2325 
2327  return verifyCastOp(*this, /*requireSameBitWidth=*/false);
2328 }
2329 
2330 //===----------------------------------------------------------------------===//
2331 // spirv.UConvertOp
2332 //===----------------------------------------------------------------------===//
2333 
2335  return verifyCastOp(*this, /*requireSameBitWidth=*/false);
2336 }
2337 
2338 //===----------------------------------------------------------------------===//
2339 // spirv.func
2340 //===----------------------------------------------------------------------===//
2341 
2342 ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
2344  SmallVector<DictionaryAttr> resultAttrs;
2345  SmallVector<Type> resultTypes;
2346  auto &builder = parser.getBuilder();
2347 
2348  // Parse the name as a symbol.
2349  StringAttr nameAttr;
2350  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
2351  result.attributes))
2352  return failure();
2353 
2354  // Parse the function signature.
2355  bool isVariadic = false;
2357  parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
2358  resultAttrs))
2359  return failure();
2360 
2361  SmallVector<Type> argTypes;
2362  for (auto &arg : entryArgs)
2363  argTypes.push_back(arg.type);
2364  auto fnType = builder.getFunctionType(argTypes, resultTypes);
2365  result.addAttribute(getFunctionTypeAttrName(result.name),
2366  TypeAttr::get(fnType));
2367 
2368  // Parse the optional function control keyword.
2369  spirv::FunctionControl fnControl;
2370  if (parseEnumStrAttr<spirv::FunctionControlAttr>(fnControl, parser, result))
2371  return failure();
2372 
2373  // If additional attributes are present, parse them.
2374  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
2375  return failure();
2376 
2377  // Add the attributes to the function arguments.
2378  assert(resultAttrs.size() == resultTypes.size());
2380  builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
2381  getResAttrsAttrName(result.name));
2382 
2383  // Parse the optional function body.
2384  auto *body = result.addRegion();
2385  OptionalParseResult parseResult =
2386  parser.parseOptionalRegion(*body, entryArgs);
2387  return failure(parseResult.has_value() && failed(*parseResult));
2388 }
2389 
2390 void spirv::FuncOp::print(OpAsmPrinter &printer) {
2391  // Print function name, signature, and control.
2392  printer << " ";
2393  printer.printSymbolName(getSymName());
2394  auto fnType = getFunctionType();
2396  printer, *this, fnType.getInputs(),
2397  /*isVariadic=*/false, fnType.getResults());
2398  printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl())
2399  << "\"";
2401  printer, *this,
2402  {spirv::attributeName<spirv::FunctionControl>(),
2403  getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
2404  getFunctionControlAttrName()});
2405 
2406  // Print the body if this is not an external function.
2407  Region &body = this->getBody();
2408  if (!body.empty()) {
2409  printer << ' ';
2410  printer.printRegion(body, /*printEntryBlockArgs=*/false,
2411  /*printBlockTerminators=*/true);
2412  }
2413 }
2414 
2415 LogicalResult spirv::FuncOp::verifyType() {
2416  if (getFunctionType().getNumResults() > 1)
2417  return emitOpError("cannot have more than one result");
2418  return success();
2419 }
2420 
2421 LogicalResult spirv::FuncOp::verifyBody() {
2422  FunctionType fnType = getFunctionType();
2423 
2424  auto walkResult = walk([fnType](Operation *op) -> WalkResult {
2425  if (auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
2426  if (fnType.getNumResults() != 0)
2427  return retOp.emitOpError("cannot be used in functions returning value");
2428  } else if (auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
2429  if (fnType.getNumResults() != 1)
2430  return retOp.emitOpError(
2431  "returns 1 value but enclosing function requires ")
2432  << fnType.getNumResults() << " results";
2433 
2434  auto retOperandType = retOp.getValue().getType();
2435  auto fnResultType = fnType.getResult(0);
2436  if (retOperandType != fnResultType)
2437  return retOp.emitOpError(" return value's type (")
2438  << retOperandType << ") mismatch with function's result type ("
2439  << fnResultType << ")";
2440  }
2441  return WalkResult::advance();
2442  });
2443 
2444  // TODO: verify other bits like linkage type.
2445 
2446  return failure(walkResult.wasInterrupted());
2447 }
2448 
2449 void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
2450  StringRef name, FunctionType type,
2451  spirv::FunctionControl control,
2452  ArrayRef<NamedAttribute> attrs) {
2454  builder.getStringAttr(name));
2455  state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
2456  state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
2457  builder.getAttr<spirv::FunctionControlAttr>(control));
2458  state.attributes.append(attrs.begin(), attrs.end());
2459  state.addRegion();
2460 }
2461 
2462 // CallableOpInterface
2463 Region *spirv::FuncOp::getCallableRegion() {
2464  return isExternal() ? nullptr : &getBody();
2465 }
2466 
2467 // CallableOpInterface
2468 ArrayRef<Type> spirv::FuncOp::getCallableResults() {
2469  return getFunctionType().getResults();
2470 }
2471 
2472 // CallableOpInterface
2473 ::mlir::ArrayAttr spirv::FuncOp::getCallableArgAttrs() {
2474  return getArgAttrs().value_or(nullptr);
2475 }
2476 
2477 // CallableOpInterface
2478 ::mlir::ArrayAttr spirv::FuncOp::getCallableResAttrs() {
2479  return getResAttrs().value_or(nullptr);
2480 }
2481 
2482 //===----------------------------------------------------------------------===//
2483 // spirv.FunctionCall
2484 //===----------------------------------------------------------------------===//
2485 
2487  auto fnName = getCalleeAttr();
2488 
2489  auto funcOp = dyn_cast_or_null<spirv::FuncOp>(
2490  SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(), fnName));
2491  if (!funcOp) {
2492  return emitOpError("callee function '")
2493  << fnName.getValue() << "' not found in nearest symbol table";
2494  }
2495 
2496  auto functionType = funcOp.getFunctionType();
2497 
2498  if (getNumResults() > 1) {
2499  return emitOpError(
2500  "expected callee function to have 0 or 1 result, but provided ")
2501  << getNumResults();
2502  }
2503 
2504  if (functionType.getNumInputs() != getNumOperands()) {
2505  return emitOpError("has incorrect number of operands for callee: expected ")
2506  << functionType.getNumInputs() << ", but provided "
2507  << getNumOperands();
2508  }
2509 
2510  for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
2511  if (getOperand(i).getType() != functionType.getInput(i)) {
2512  return emitOpError("operand type mismatch: expected operand type ")
2513  << functionType.getInput(i) << ", but provided "
2514  << getOperand(i).getType() << " for operand number " << i;
2515  }
2516  }
2517 
2518  if (functionType.getNumResults() != getNumResults()) {
2519  return emitOpError(
2520  "has incorrect number of results has for callee: expected ")
2521  << functionType.getNumResults() << ", but provided "
2522  << getNumResults();
2523  }
2524 
2525  if (getNumResults() &&
2526  (getResult(0).getType() != functionType.getResult(0))) {
2527  return emitOpError("result type mismatch: expected ")
2528  << functionType.getResult(0) << ", but provided "
2529  << getResult(0).getType();
2530  }
2531 
2532  return success();
2533 }
2534 
2535 CallInterfaceCallable spirv::FunctionCallOp::getCallableForCallee() {
2536  return (*this)->getAttrOfType<SymbolRefAttr>(kCallee);
2537 }
2538 
2539 Operation::operand_range spirv::FunctionCallOp::getArgOperands() {
2540  return getArguments();
2541 }
2542 
2543 //===----------------------------------------------------------------------===//
2544 // spirv.GLFClampOp
2545 //===----------------------------------------------------------------------===//
2546 
2547 ParseResult spirv::GLFClampOp::parse(OpAsmParser &parser,
2548  OperationState &result) {
2549  return parseOneResultSameOperandTypeOp(parser, result);
2550 }
2552 
2553 //===----------------------------------------------------------------------===//
2554 // spirv.GLUClampOp
2555 //===----------------------------------------------------------------------===//
2556 
2557 ParseResult spirv::GLUClampOp::parse(OpAsmParser &parser,
2558  OperationState &result) {
2559  return parseOneResultSameOperandTypeOp(parser, result);
2560 }
2562 
2563 //===----------------------------------------------------------------------===//
2564 // spirv.GLSClampOp
2565 //===----------------------------------------------------------------------===//
2566 
2567 ParseResult spirv::GLSClampOp::parse(OpAsmParser &parser,
2568  OperationState &result) {
2569  return parseOneResultSameOperandTypeOp(parser, result);
2570 }
2572 
2573 //===----------------------------------------------------------------------===//
2574 // spirv.GLFmaOp
2575 //===----------------------------------------------------------------------===//
2576 
2577 ParseResult spirv::GLFmaOp::parse(OpAsmParser &parser, OperationState &result) {
2578  return parseOneResultSameOperandTypeOp(parser, result);
2579 }
2580 void spirv::GLFmaOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
2581 
2582 //===----------------------------------------------------------------------===//
2583 // spirv.GlobalVariable
2584 //===----------------------------------------------------------------------===//
2585 
2586 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
2587  Type type, StringRef name,
2588  unsigned descriptorSet, unsigned binding) {
2589  build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
2590  state.addAttribute(
2591  spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
2592  builder.getI32IntegerAttr(descriptorSet));
2593  state.addAttribute(
2594  spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
2595  builder.getI32IntegerAttr(binding));
2596 }
2597 
2598 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
2599  Type type, StringRef name,
2600  spirv::BuiltIn builtin) {
2601  build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
2602  state.addAttribute(
2603  spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
2604  builder.getStringAttr(spirv::stringifyBuiltIn(builtin)));
2605 }
2606 
2607 ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
2608  OperationState &result) {
2609  // Parse variable name.
2610  StringAttr nameAttr;
2611  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
2612  result.attributes)) {
2613  return failure();
2614  }
2615 
2616  // Parse optional initializer
2618  FlatSymbolRefAttr initSymbol;
2619  if (parser.parseLParen() ||
2620  parser.parseAttribute(initSymbol, Type(), kInitializerAttrName,
2621  result.attributes) ||
2622  parser.parseRParen())
2623  return failure();
2624  }
2625 
2626  if (parseVariableDecorations(parser, result)) {
2627  return failure();
2628  }
2629 
2630  Type type;
2631  auto loc = parser.getCurrentLocation();
2632  if (parser.parseColonType(type)) {
2633  return failure();
2634  }
2635  if (!type.isa<spirv::PointerType>()) {
2636  return parser.emitError(loc, "expected spirv.ptr type");
2637  }
2638  result.addAttribute(kTypeAttrName, TypeAttr::get(type));
2639 
2640  return success();
2641 }
2642 
2644  SmallVector<StringRef, 4> elidedAttrs{
2645  spirv::attributeName<spirv::StorageClass>()};
2646 
2647  // Print variable name.
2648  printer << ' ';
2649  printer.printSymbolName(getSymName());
2650  elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
2651 
2652  // Print optional initializer
2653  if (auto initializer = this->getInitializer()) {
2654  printer << " " << kInitializerAttrName << '(';
2655  printer.printSymbolName(*initializer);
2656  printer << ')';
2657  elidedAttrs.push_back(kInitializerAttrName);
2658  }
2659 
2660  elidedAttrs.push_back(kTypeAttrName);
2661  printVariableDecorations(*this, printer, elidedAttrs);
2662  printer << " : " << getType();
2663 }
2664 
2666  if (!getType().isa<spirv::PointerType>())
2667  return emitOpError("result must be of a !spv.ptr type");
2668 
2669  // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
2670  // object. It cannot be Generic. It must be the same as the Storage Class
2671  // operand of the Result Type."
2672  // Also, Function storage class is reserved by spirv.Variable.
2673  auto storageClass = this->storageClass();
2674  if (storageClass == spirv::StorageClass::Generic ||
2675  storageClass == spirv::StorageClass::Function) {
2676  return emitOpError("storage class cannot be '")
2677  << stringifyStorageClass(storageClass) << "'";
2678  }
2679 
2680  if (auto init =
2681  (*this)->getAttrOfType<FlatSymbolRefAttr>(kInitializerAttrName)) {
2683  (*this)->getParentOp(), init.getAttr());
2684  // TODO: Currently only variable initialization with specialization
2685  // constants and other variables is supported. They could be normal
2686  // constants in the module scope as well.
2687  if (!initOp ||
2688  !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp>(initOp)) {
2689  return emitOpError("initializer must be result of a "
2690  "spirv.SpecConstant or spirv.GlobalVariable op");
2691  }
2692  }
2693 
2694  return success();
2695 }
2696 
2697 //===----------------------------------------------------------------------===//
2698 // spirv.GroupBroadcast
2699 //===----------------------------------------------------------------------===//
2700 
2702  spirv::Scope scope = getExecutionScope();
2703  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2704  return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
2705 
2706  if (auto localIdTy = getLocalid().getType().dyn_cast<VectorType>())
2707  if (localIdTy.getNumElements() != 2 && localIdTy.getNumElements() != 3)
2708  return emitOpError("localid is a vector and can be with only "
2709  " 2 or 3 components, actual number is ")
2710  << localIdTy.getNumElements();
2711 
2712  return success();
2713 }
2714 
2715 //===----------------------------------------------------------------------===//
2716 // spirv.GroupNonUniformBallotOp
2717 //===----------------------------------------------------------------------===//
2718 
2720  spirv::Scope scope = getExecutionScope();
2721  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2722  return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
2723 
2724  return success();
2725 }
2726 
2727 //===----------------------------------------------------------------------===//
2728 // spirv.GroupNonUniformBroadcast
2729 //===----------------------------------------------------------------------===//
2730 
2732  spirv::Scope scope = getExecutionScope();
2733  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2734  return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
2735 
2736  // SPIR-V spec: "Before version 1.5, Id must come from a
2737  // constant instruction.
2738  auto targetEnv = spirv::getDefaultTargetEnv(getContext());
2739  if (auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>())
2740  targetEnv = spirv::lookupTargetEnvOrDefault(spirvModule);
2741 
2742  if (targetEnv.getVersion() < spirv::Version::V_1_5) {
2743  auto *idOp = getId().getDefiningOp();
2744  if (!idOp || !isa<spirv::ConstantOp, // for normal constant
2745  spirv::ReferenceOfOp>(idOp)) // for spec constant
2746  return emitOpError("id must be the result of a constant op");
2747  }
2748 
2749  return success();
2750 }
2751 
2752 //===----------------------------------------------------------------------===//
2753 // spirv.GroupNonUniformShuffle*
2754 //===----------------------------------------------------------------------===//
2755 
2756 template <typename OpTy>
2758  spirv::Scope scope = op.getExecutionScope();
2759  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2760  return op.emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
2761 
2762  if (op.getOperands().back().getType().isSignedInteger())
2763  return op.emitOpError("second operand must be a singless/unsigned integer");
2764 
2765  return success();
2766 }
2767 
2769  return verifyGroupNonUniformShuffleOp(*this);
2770 }
2772  return verifyGroupNonUniformShuffleOp(*this);
2773 }
2775  return verifyGroupNonUniformShuffleOp(*this);
2776 }
2778  return verifyGroupNonUniformShuffleOp(*this);
2779 }
2780 
2781 //===----------------------------------------------------------------------===//
2782 // spirv.INTEL.SubgroupBlockRead
2783 //===----------------------------------------------------------------------===//
2784 
2785 ParseResult spirv::INTELSubgroupBlockReadOp::parse(OpAsmParser &parser,
2786  OperationState &result) {
2787  // Parse the storage class specification
2788  spirv::StorageClass storageClass;
2790  Type elementType;
2791  if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
2792  parser.parseColon() || parser.parseType(elementType)) {
2793  return failure();
2794  }
2795 
2796  auto ptrType = spirv::PointerType::get(elementType, storageClass);
2797  if (auto valVecTy = elementType.dyn_cast<VectorType>())
2798  ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
2799 
2800  if (parser.resolveOperand(ptrInfo, ptrType, result.operands)) {
2801  return failure();
2802  }
2803 
2804  result.addTypes(elementType);
2805  return success();
2806 }
2807 
2809  printer << " " << getPtr() << " : " << getType();
2810 }
2811 
2813  if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
2814  return failure();
2815 
2816  return success();
2817 }
2818 
2819 //===----------------------------------------------------------------------===//
2820 // spirv.INTEL.SubgroupBlockWrite
2821 //===----------------------------------------------------------------------===//
2822 
2823 ParseResult spirv::INTELSubgroupBlockWriteOp::parse(OpAsmParser &parser,
2824  OperationState &result) {
2825  // Parse the storage class specification
2826  spirv::StorageClass storageClass;
2828  auto loc = parser.getCurrentLocation();
2829  Type elementType;
2830  if (parseEnumStrAttr(storageClass, parser) ||
2831  parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
2832  parser.parseType(elementType)) {
2833  return failure();
2834  }
2835 
2836  auto ptrType = spirv::PointerType::get(elementType, storageClass);
2837  if (auto valVecTy = elementType.dyn_cast<VectorType>())
2838  ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
2839 
2840  if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
2841  result.operands)) {
2842  return failure();
2843  }
2844  return success();
2845 }
2846 
2848  printer << " " << getPtr() << ", " << getValue() << " : "
2849  << getValue().getType();
2850 }
2851 
2853  if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
2854  return failure();
2855 
2856  return success();
2857 }
2858 
2859 //===----------------------------------------------------------------------===//
2860 // spirv.GroupNonUniformElectOp
2861 //===----------------------------------------------------------------------===//
2862 
2864  spirv::Scope scope = getExecutionScope();
2865  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2866  return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
2867 
2868  return success();
2869 }
2870 
2871 //===----------------------------------------------------------------------===//
2872 // spirv.GroupNonUniformFAddOp
2873 //===----------------------------------------------------------------------===//
2874 
2876  return verifyGroupNonUniformArithmeticOp(*this);
2877 }
2878 
2879 ParseResult spirv::GroupNonUniformFAddOp::parse(OpAsmParser &parser,
2880  OperationState &result) {
2881  return parseGroupNonUniformArithmeticOp(parser, result);
2882 }
2885 }
2886 
2887 //===----------------------------------------------------------------------===//
2888 // spirv.GroupNonUniformFMaxOp
2889 //===----------------------------------------------------------------------===//
2890 
2892  return verifyGroupNonUniformArithmeticOp(*this);
2893 }
2894 
2895 ParseResult spirv::GroupNonUniformFMaxOp::parse(OpAsmParser &parser,
2896  OperationState &result) {
2897  return parseGroupNonUniformArithmeticOp(parser, result);
2898 }
2901 }
2902 
2903 //===----------------------------------------------------------------------===//
2904 // spirv.GroupNonUniformFMinOp
2905 //===----------------------------------------------------------------------===//
2906 
2908  return verifyGroupNonUniformArithmeticOp(*this);
2909 }
2910 
2911 ParseResult spirv::GroupNonUniformFMinOp::parse(OpAsmParser &parser,
2912  OperationState &result) {
2913  return parseGroupNonUniformArithmeticOp(parser, result);
2914 }
2917 }
2918 
2919 //===----------------------------------------------------------------------===//
2920 // spirv.GroupNonUniformFMulOp
2921 //===----------------------------------------------------------------------===//
2922 
2924  return verifyGroupNonUniformArithmeticOp(*this);
2925 }
2926 
2927 ParseResult spirv::GroupNonUniformFMulOp::parse(OpAsmParser &parser,
2928  OperationState &result) {
2929  return parseGroupNonUniformArithmeticOp(parser, result);
2930 }
2933 }
2934 
2935 //===----------------------------------------------------------------------===//
2936 // spirv.GroupNonUniformIAddOp
2937 //===----------------------------------------------------------------------===//
2938 
2940  return verifyGroupNonUniformArithmeticOp(*this);
2941 }
2942 
2943 ParseResult spirv::GroupNonUniformIAddOp::parse(OpAsmParser &parser,
2944  OperationState &result) {
2945  return parseGroupNonUniformArithmeticOp(parser, result);
2946 }
2949 }
2950 
2951 //===----------------------------------------------------------------------===//
2952 // spirv.GroupNonUniformIMulOp
2953 //===----------------------------------------------------------------------===//
2954 
2956  return verifyGroupNonUniformArithmeticOp(*this);
2957 }
2958 
2959 ParseResult spirv::GroupNonUniformIMulOp::parse(OpAsmParser &parser,
2960  OperationState &result) {
2961  return parseGroupNonUniformArithmeticOp(parser, result);
2962 }
2965 }
2966 
2967 //===----------------------------------------------------------------------===//
2968 // spirv.GroupNonUniformSMaxOp
2969 //===----------------------------------------------------------------------===//
2970 
2972  return verifyGroupNonUniformArithmeticOp(*this);
2973 }
2974 
2975 ParseResult spirv::GroupNonUniformSMaxOp::parse(OpAsmParser &parser,
2976  OperationState &result) {
2977  return parseGroupNonUniformArithmeticOp(parser, result);
2978 }
2981 }
2982 
2983 //===----------------------------------------------------------------------===//
2984 // spirv.GroupNonUniformSMinOp
2985 //===----------------------------------------------------------------------===//
2986 
2988  return verifyGroupNonUniformArithmeticOp(*this);
2989 }
2990 
2991 ParseResult spirv::GroupNonUniformSMinOp::parse(OpAsmParser &parser,
2992  OperationState &result) {
2993  return parseGroupNonUniformArithmeticOp(parser, result);
2994 }
2997 }
2998 
2999 //===----------------------------------------------------------------------===//
3000 // spirv.GroupNonUniformUMaxOp
3001 //===----------------------------------------------------------------------===//
3002 
3004  return verifyGroupNonUniformArithmeticOp(*this);
3005 }
3006 
3007 ParseResult spirv::GroupNonUniformUMaxOp::parse(OpAsmParser &parser,
3008  OperationState &result) {
3009  return parseGroupNonUniformArithmeticOp(parser, result);
3010 }
3013 }
3014 
3015 //===----------------------------------------------------------------------===//
3016 // spirv.GroupNonUniformUMinOp
3017 //===----------------------------------------------------------------------===//
3018 
3020  return verifyGroupNonUniformArithmeticOp(*this);
3021 }
3022 
3023 ParseResult spirv::GroupNonUniformUMinOp::parse(OpAsmParser &parser,
3024  OperationState &result) {
3025  return parseGroupNonUniformArithmeticOp(parser, result);
3026 }
3029 }
3030 
3031 //===----------------------------------------------------------------------===//
3032 // spirv.IAddCarryOp
3033 //===----------------------------------------------------------------------===//
3034 
3037 }
3038 
3039 ParseResult spirv::IAddCarryOp::parse(OpAsmParser &parser,
3040  OperationState &result) {
3042 }
3043 
3044 void spirv::IAddCarryOp::print(OpAsmPrinter &printer) {
3045  ::printArithmeticExtendedBinaryOp(*this, printer);
3046 }
3047 
3048 //===----------------------------------------------------------------------===//
3049 // spirv.ISubBorrowOp
3050 //===----------------------------------------------------------------------===//
3051 
3054 }
3055 
3056 ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser,
3057  OperationState &result) {
3059 }
3060 
3062  ::printArithmeticExtendedBinaryOp(*this, printer);
3063 }
3064 
3065 //===----------------------------------------------------------------------===//
3066 // spirv.SMulExtended
3067 //===----------------------------------------------------------------------===//
3068 
3071 }
3072 
3073 ParseResult spirv::SMulExtendedOp::parse(OpAsmParser &parser,
3074  OperationState &result) {
3076 }
3077 
3079  ::printArithmeticExtendedBinaryOp(*this, printer);
3080 }
3081 
3082 //===----------------------------------------------------------------------===//
3083 // spirv.UMulExtended
3084 //===----------------------------------------------------------------------===//
3085 
3088 }
3089 
3090 ParseResult spirv::UMulExtendedOp::parse(OpAsmParser &parser,
3091  OperationState &result) {
3093 }
3094 
3096  ::printArithmeticExtendedBinaryOp(*this, printer);
3097 }
3098 
3099 //===----------------------------------------------------------------------===//
3100 // spirv.LoadOp
3101 //===----------------------------------------------------------------------===//
3102 
3103 void spirv::LoadOp::build(OpBuilder &builder, OperationState &state,
3104  Value basePtr, MemoryAccessAttr memoryAccess,
3105  IntegerAttr alignment) {
3106  auto ptrType = basePtr.getType().cast<spirv::PointerType>();
3107  build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess,
3108  alignment);
3109 }
3110 
3111 ParseResult spirv::LoadOp::parse(OpAsmParser &parser, OperationState &result) {
3112  // Parse the storage class specification
3113  spirv::StorageClass storageClass;
3115  Type elementType;
3116  if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
3117  parseMemoryAccessAttributes(parser, result) ||
3118  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
3119  parser.parseType(elementType)) {
3120  return failure();
3121  }
3122 
3123  auto ptrType = spirv::PointerType::get(elementType, storageClass);
3124  if (parser.resolveOperand(ptrInfo, ptrType, result.operands)) {
3125  return failure();
3126  }
3127 
3128  result.addTypes(elementType);
3129  return success();
3130 }
3131 
3132 void spirv::LoadOp::print(OpAsmPrinter &printer) {
3133  SmallVector<StringRef, 4> elidedAttrs;
3134  StringRef sc = stringifyStorageClass(
3135  getPtr().getType().cast<spirv::PointerType>().getStorageClass());
3136  printer << " \"" << sc << "\" " << getPtr();
3137 
3138  printMemoryAccessAttribute(*this, printer, elidedAttrs);
3139 
3140  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
3141  printer << " : " << getType();
3142 }
3143 
3145  // SPIR-V spec : "Result Type is the type of the loaded object. It must be a
3146  // type with fixed size; i.e., it cannot be, nor include, any
3147  // OpTypeRuntimeArray types."
3148  if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue()))) {
3149  return failure();
3150  }
3151  return verifyMemoryAccessAttribute(*this);
3152 }
3153 
3154 //===----------------------------------------------------------------------===//
3155 // spirv.mlir.loop
3156 //===----------------------------------------------------------------------===//
3157 
3158 void spirv::LoopOp::build(OpBuilder &builder, OperationState &state) {
3159  state.addAttribute("loop_control", builder.getAttr<spirv::LoopControlAttr>(
3161  state.addRegion();
3162 }
3163 
3164 ParseResult spirv::LoopOp::parse(OpAsmParser &parser, OperationState &result) {
3165  if (parseControlAttribute<spirv::LoopControlAttr, spirv::LoopControl>(parser,
3166  result))
3167  return failure();
3168  return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
3169 }
3170 
3171 void spirv::LoopOp::print(OpAsmPrinter &printer) {
3172  auto control = getLoopControl();
3173  if (control != spirv::LoopControl::None)
3174  printer << " control(" << spirv::stringifyLoopControl(control) << ")";
3175  printer << ' ';
3176  printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
3177  /*printBlockTerminators=*/true);
3178 }
3179 
3180 /// Returns true if the given `srcBlock` contains only one `spirv.Branch` to the
3181 /// given `dstBlock`.
3182 static inline bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) {
3183  // Check that there is only one op in the `srcBlock`.
3184  if (!llvm::hasSingleElement(srcBlock))
3185  return false;
3186 
3187  auto branchOp = dyn_cast<spirv::BranchOp>(srcBlock.back());
3188  return branchOp && branchOp.getSuccessor() == &dstBlock;
3189 }
3190 
3191 LogicalResult spirv::LoopOp::verifyRegions() {
3192  auto *op = getOperation();
3193 
3194  // We need to verify that the blocks follow the following layout:
3195  //
3196  // +-------------+
3197  // | entry block |
3198  // +-------------+
3199  // |
3200  // v
3201  // +-------------+
3202  // | loop header | <-----+
3203  // +-------------+ |
3204  // |
3205  // ... |
3206  // \ | / |
3207  // v |
3208  // +---------------+ |
3209  // | loop continue | -----+
3210  // +---------------+
3211  //
3212  // ...
3213  // \ | /
3214  // v
3215  // +-------------+
3216  // | merge block |
3217  // +-------------+
3218 
3219  auto &region = op->getRegion(0);
3220  // Allow empty region as a degenerated case, which can come from
3221  // optimizations.
3222  if (region.empty())
3223  return success();
3224 
3225  // The last block is the merge block.
3226  Block &merge = region.back();
3227  if (!isMergeBlock(merge))
3228  return emitOpError("last block must be the merge block with only one "
3229  "'spirv.mlir.merge' op");
3230 
3231  if (std::next(region.begin()) == region.end())
3232  return emitOpError(
3233  "must have an entry block branching to the loop header block");
3234  // The first block is the entry block.
3235  Block &entry = region.front();
3236 
3237  if (std::next(region.begin(), 2) == region.end())
3238  return emitOpError(
3239  "must have a loop header block branched from the entry block");
3240  // The second block is the loop header block.
3241  Block &header = *std::next(region.begin(), 1);
3242 
3243  if (!hasOneBranchOpTo(entry, header))
3244  return emitOpError(
3245  "entry block must only have one 'spirv.Branch' op to the second block");
3246 
3247  if (std::next(region.begin(), 3) == region.end())
3248  return emitOpError(
3249  "requires a loop continue block branching to the loop header block");
3250  // The second to last block is the loop continue block.
3251  Block &cont = *std::prev(region.end(), 2);
3252 
3253  // Make sure that we have a branch from the loop continue block to the loop
3254  // header block.
3255  if (llvm::none_of(
3256  llvm::seq<unsigned>(0, cont.getNumSuccessors()),
3257  [&](unsigned index) { return cont.getSuccessor(index) == &header; }))
3258  return emitOpError("second to last block must be the loop continue "
3259  "block that branches to the loop header block");
3260 
3261  // Make sure that no other blocks (except the entry and loop continue block)
3262  // branches to the loop header block.
3263  for (auto &block : llvm::make_range(std::next(region.begin(), 2),
3264  std::prev(region.end(), 2))) {
3265  for (auto i : llvm::seq<unsigned>(0, block.getNumSuccessors())) {
3266  if (block.getSuccessor(i) == &header) {
3267  return emitOpError("can only have the entry and loop continue "
3268  "block branching to the loop header block");
3269  }
3270  }
3271  }
3272 
3273  return success();
3274 }
3275 
3276 Block *spirv::LoopOp::getEntryBlock() {
3277  assert(!getBody().empty() && "op region should not be empty!");
3278  return &getBody().front();
3279 }
3280 
3281 Block *spirv::LoopOp::getHeaderBlock() {
3282  assert(!getBody().empty() && "op region should not be empty!");
3283  // The second block is the loop header block.
3284  return &*std::next(getBody().begin());
3285 }
3286 
3287 Block *spirv::LoopOp::getContinueBlock() {
3288  assert(!getBody().empty() && "op region should not be empty!");
3289  // The second to last block is the loop continue block.
3290  return &*std::prev(getBody().end(), 2);
3291 }
3292 
3293 Block *spirv::LoopOp::getMergeBlock() {
3294  assert(!getBody().empty() && "op region should not be empty!");
3295  // The last block is the loop merge block.
3296  return &getBody().back();
3297 }
3298 
3299 void spirv::LoopOp::addEntryAndMergeBlock() {
3300  assert(getBody().empty() && "entry and merge block already exist");
3301  getBody().push_back(new Block());
3302  auto *mergeBlock = new Block();
3303  getBody().push_back(mergeBlock);
3304  OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
3305 
3306  // Add a spirv.mlir.merge op into the merge block.
3307  builder.create<spirv::MergeOp>(getLoc());
3308 }
3309 
3310 //===----------------------------------------------------------------------===//
3311 // spirv.MemoryBarrierOp
3312 //===----------------------------------------------------------------------===//
3313 
3315  return verifyMemorySemantics(getOperation(), getMemorySemantics());
3316 }
3317 
3318 //===----------------------------------------------------------------------===//
3319 // spirv.mlir.merge
3320 //===----------------------------------------------------------------------===//
3321 
3323  auto *parentOp = (*this)->getParentOp();
3324  if (!parentOp || !isa<spirv::SelectionOp, spirv::LoopOp>(parentOp))
3325  return emitOpError(
3326  "expected parent op to be 'spirv.mlir.selection' or 'spirv.mlir.loop'");
3327 
3328  // TODO: This check should be done in `verifyRegions` of parent op.
3329  Block &parentLastBlock = (*this)->getParentRegion()->back();
3330  if (getOperation() != parentLastBlock.getTerminator())
3331  return emitOpError("can only be used in the last block of "
3332  "'spirv.mlir.selection' or 'spirv.mlir.loop'");
3333  return success();
3334 }
3335 
3336 //===----------------------------------------------------------------------===//
3337 // spirv.module
3338 //===----------------------------------------------------------------------===//
3339 
3340 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
3341  std::optional<StringRef> name) {
3342  OpBuilder::InsertionGuard guard(builder);
3343  builder.createBlock(state.addRegion());
3344  if (name) {
3346  builder.getStringAttr(*name));
3347  }
3348 }
3349 
3350 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
3351  spirv::AddressingModel addressingModel,
3352  spirv::MemoryModel memoryModel,
3353  std::optional<VerCapExtAttr> vceTriple,
3354  std::optional<StringRef> name) {
3355  state.addAttribute(
3356  "addressing_model",
3357  builder.getAttr<spirv::AddressingModelAttr>(addressingModel));
3358  state.addAttribute("memory_model",
3359  builder.getAttr<spirv::MemoryModelAttr>(memoryModel));
3360  OpBuilder::InsertionGuard guard(builder);
3361  builder.createBlock(state.addRegion());
3362  if (vceTriple)
3363  state.addAttribute(getVCETripleAttrName(), *vceTriple);
3364  if (name)
3366  builder.getStringAttr(*name));
3367 }
3368 
3369 ParseResult spirv::ModuleOp::parse(OpAsmParser &parser,
3370  OperationState &result) {
3371  Region *body = result.addRegion();
3372 
3373  // If the name is present, parse it.
3374  StringAttr nameAttr;
3375  (void)parser.parseOptionalSymbolName(
3376  nameAttr, mlir::SymbolTable::getSymbolAttrName(), result.attributes);
3377 
3378  // Parse attributes
3379  spirv::AddressingModel addrModel;
3380  spirv::MemoryModel memoryModel;
3381  if (::parseEnumKeywordAttr<spirv::AddressingModelAttr>(addrModel, parser,
3382  result) ||
3383  ::parseEnumKeywordAttr<spirv::MemoryModelAttr>(memoryModel, parser,
3384  result))
3385  return failure();
3386 
3387  if (succeeded(parser.parseOptionalKeyword("requires"))) {
3388  spirv::VerCapExtAttr vceTriple;
3389  if (parser.parseAttribute(vceTriple,
3390  spirv::ModuleOp::getVCETripleAttrName(),
3391  result.attributes))
3392  return failure();
3393  }
3394 
3395  if (parser.parseOptionalAttrDictWithKeyword(result.attributes) ||
3396  parser.parseRegion(*body, /*arguments=*/{}))
3397  return failure();
3398 
3399  // Make sure we have at least one block.
3400  if (body->empty())
3401  body->push_back(new Block());
3402 
3403  return success();
3404 }
3405 
3406 void spirv::ModuleOp::print(OpAsmPrinter &printer) {
3407  if (std::optional<StringRef> name = getName()) {
3408  printer << ' ';
3409  printer.printSymbolName(*name);
3410  }
3411 
3412  SmallVector<StringRef, 2> elidedAttrs;
3413 
3414  printer << " " << spirv::stringifyAddressingModel(getAddressingModel()) << " "
3415  << spirv::stringifyMemoryModel(getMemoryModel());
3416  auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
3417  auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
3418  elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
3420 
3421  if (std::optional<spirv::VerCapExtAttr> triple = getVceTriple()) {
3422  printer << " requires " << *triple;
3423  elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
3424  }
3425 
3426  printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs);
3427  printer << ' ';
3428  printer.printRegion(getRegion());
3429 }
3430 
3431 LogicalResult spirv::ModuleOp::verifyRegions() {
3432  Dialect *dialect = (*this)->getDialect();
3434  entryPoints;
3435  mlir::SymbolTable table(*this);
3436 
3437  for (auto &op : *getBody()) {
3438  if (op.getDialect() != dialect)
3439  return op.emitError("'spirv.module' can only contain spirv.* ops");
3440 
3441  // For EntryPoint op, check that the function and execution model is not
3442  // duplicated in EntryPointOps. Also verify that the interface specified
3443  // comes from globalVariables here to make this check cheaper.
3444  if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
3445  auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.getFn());
3446  if (!funcOp) {
3447  return entryPointOp.emitError("function '")
3448  << entryPointOp.getFn() << "' not found in 'spirv.module'";
3449  }
3450  if (auto interface = entryPointOp.getInterface()) {
3451  for (Attribute varRef : interface) {
3452  auto varSymRef = varRef.dyn_cast<FlatSymbolRefAttr>();
3453  if (!varSymRef) {
3454  return entryPointOp.emitError(
3455  "expected symbol reference for interface "
3456  "specification instead of '")
3457  << varRef;
3458  }
3459  auto variableOp =
3460  table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
3461  if (!variableOp) {
3462  return entryPointOp.emitError("expected spirv.GlobalVariable "
3463  "symbol reference instead of'")
3464  << varSymRef << "'";
3465  }
3466  }
3467  }
3468 
3469  auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
3470  funcOp, entryPointOp.getExecutionModel());
3471  auto entryPtIt = entryPoints.find(key);
3472  if (entryPtIt != entryPoints.end()) {
3473  return entryPointOp.emitError("duplicate of a previous EntryPointOp");
3474  }
3475  entryPoints[key] = entryPointOp;
3476  } else if (auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
3477  if (funcOp.isExternal())
3478  return op.emitError("'spirv.module' cannot contain external functions");
3479 
3480  // TODO: move this check to spirv.func.
3481  for (auto &block : funcOp)
3482  for (auto &op : block) {
3483  if (op.getDialect() != dialect)
3484  return op.emitError(
3485  "functions in 'spirv.module' can only contain spirv.* ops");
3486  }
3487  }
3488  }
3489 
3490  return success();
3491 }
3492 
3493 //===----------------------------------------------------------------------===//
3494 // spirv.mlir.referenceof
3495 //===----------------------------------------------------------------------===//
3496 
3498  auto *specConstSym = SymbolTable::lookupNearestSymbolFrom(
3499  (*this)->getParentOp(), getSpecConstAttr());
3500  Type constType;
3501 
3502  auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
3503  if (specConstOp)
3504  constType = specConstOp.getDefaultValue().getType();
3505 
3506  auto specConstCompositeOp =
3507  dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
3508  if (specConstCompositeOp)
3509  constType = specConstCompositeOp.getType();
3510 
3511  if (!specConstOp && !specConstCompositeOp)
3512  return emitOpError(
3513  "expected spirv.SpecConstant or spirv.SpecConstantComposite symbol");
3514 
3515  if (getReference().getType() != constType)
3516  return emitOpError("result type mismatch with the referenced "
3517  "specialization constant's type");
3518 
3519  return success();
3520 }
3521 
3522 //===----------------------------------------------------------------------===//
3523 // spirv.Return
3524 //===----------------------------------------------------------------------===//
3525 
3527  // Verification is performed in spirv.func op.
3528  return success();
3529 }
3530 
3531 //===----------------------------------------------------------------------===//
3532 // spirv.ReturnValue
3533 //===----------------------------------------------------------------------===//
3534 
3536  // Verification is performed in spirv.func op.
3537  return success();
3538 }
3539 
3540 //===----------------------------------------------------------------------===//
3541 // spirv.Select
3542 //===----------------------------------------------------------------------===//
3543 
3545  if (auto conditionTy = getCondition().getType().dyn_cast<VectorType>()) {
3546  auto resultVectorTy = getResult().getType().dyn_cast<VectorType>();
3547  if (!resultVectorTy) {
3548  return emitOpError("result expected to be of vector type when "
3549  "condition is of vector type");
3550  }
3551  if (resultVectorTy.getNumElements() != conditionTy.getNumElements()) {
3552  return emitOpError("result should have the same number of elements as "
3553  "the condition when condition is of vector type");
3554  }
3555  }
3556  return success();
3557 }
3558 
3559 //===----------------------------------------------------------------------===//
3560 // spirv.mlir.selection
3561 //===----------------------------------------------------------------------===//
3562 
3563 ParseResult spirv::SelectionOp::parse(OpAsmParser &parser,
3564  OperationState &result) {
3565  if (parseControlAttribute<spirv::SelectionControlAttr,
3566  spirv::SelectionControl>(parser, result))
3567  return failure();
3568  return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
3569 }
3570 
3571 void spirv::SelectionOp::print(OpAsmPrinter &printer) {
3572  auto control = getSelectionControl();
3573  if (control != spirv::SelectionControl::None)
3574  printer << " control(" << spirv::stringifySelectionControl(control) << ")";
3575  printer << ' ';
3576  printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
3577  /*printBlockTerminators=*/true);
3578 }
3579 
3580 LogicalResult spirv::SelectionOp::verifyRegions() {
3581  auto *op = getOperation();
3582 
3583  // We need to verify that the blocks follow the following layout:
3584  //
3585  // +--------------+
3586  // | header block |
3587  // +--------------+
3588  // / | \
3589  // ...
3590  //
3591  //
3592  // +---------+ +---------+ +---------+
3593  // | case #0 | | case #1 | | case #2 | ...
3594  // +---------+ +---------+ +---------+
3595  //
3596  //
3597  // ...
3598  // \ | /
3599  // v
3600  // +-------------+
3601  // | merge block |
3602  // +-------------+
3603 
3604  auto &region = op->getRegion(0);
3605  // Allow empty region as a degenerated case, which can come from
3606  // optimizations.
3607  if (region.empty())
3608  return success();
3609 
3610  // The last block is the merge block.
3611  if (!isMergeBlock(region.back()))
3612  return emitOpError("last block must be the merge block with only one "
3613  "'spirv.mlir.merge' op");
3614 
3615  if (std::next(region.begin()) == region.end())
3616  return emitOpError("must have a selection header block");
3617 
3618  return success();
3619 }
3620 
3621 Block *spirv::SelectionOp::getHeaderBlock() {
3622  assert(!getBody().empty() && "op region should not be empty!");
3623  // The first block is the loop header block.
3624  return &getBody().front();
3625 }
3626 
3627 Block *spirv::SelectionOp::getMergeBlock() {
3628  assert(!getBody().empty() && "op region should not be empty!");
3629  // The last block is the loop merge block.
3630  return &getBody().back();
3631 }
3632 
3633 void spirv::SelectionOp::addMergeBlock() {
3634  assert(getBody().empty() && "entry and merge block already exist");
3635  auto *mergeBlock = new Block();
3636  getBody().push_back(mergeBlock);
3637  OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
3638 
3639  // Add a spirv.mlir.merge op into the merge block.
3640  builder.create<spirv::MergeOp>(getLoc());
3641 }
3642 
3643 spirv::SelectionOp spirv::SelectionOp::createIfThen(
3644  Location loc, Value condition,
3645  function_ref<void(OpBuilder &builder)> thenBody, OpBuilder &builder) {
3646  auto selectionOp =
3647  builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
3648 
3649  selectionOp.addMergeBlock();
3650  Block *mergeBlock = selectionOp.getMergeBlock();
3651  Block *thenBlock = nullptr;
3652 
3653  // Build the "then" block.
3654  {
3655  OpBuilder::InsertionGuard guard(builder);
3656  thenBlock = builder.createBlock(mergeBlock);
3657  thenBody(builder);
3658  builder.create<spirv::BranchOp>(loc, mergeBlock);
3659  }
3660 
3661  // Build the header block.
3662  {
3663  OpBuilder::InsertionGuard guard(builder);
3664  builder.createBlock(thenBlock);
3665  builder.create<spirv::BranchConditionalOp>(
3666  loc, condition, thenBlock,
3667  /*trueArguments=*/ArrayRef<Value>(), mergeBlock,
3668  /*falseArguments=*/ArrayRef<Value>());
3669  }
3670 
3671  return selectionOp;
3672 }
3673 
3674 //===----------------------------------------------------------------------===//
3675 // spirv.SpecConstant
3676 //===----------------------------------------------------------------------===//
3677 
3678 ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
3679  OperationState &result) {
3680  StringAttr nameAttr;
3681  Attribute valueAttr;
3682 
3683  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
3684  result.attributes))
3685  return failure();
3686 
3687  // Parse optional spec_id.
3689  IntegerAttr specIdAttr;
3690  if (parser.parseLParen() ||
3691  parser.parseAttribute(specIdAttr, kSpecIdAttrName, result.attributes) ||
3692  parser.parseRParen())
3693  return failure();
3694  }
3695 
3696  if (parser.parseEqual() ||
3697  parser.parseAttribute(valueAttr, kDefaultValueAttrName,
3698  result.attributes))
3699  return failure();
3700 
3701  return success();
3702 }
3703 
3705  printer << ' ';
3706  printer.printSymbolName(getSymName());
3707  if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
3708  printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')';
3709  printer << " = " << getDefaultValue();
3710 }
3711 
3713  if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
3714  if (specID.getValue().isNegative())
3715  return emitOpError("SpecId cannot be negative");
3716 
3717  auto value = getDefaultValue();
3718  if (value.isa<IntegerAttr, FloatAttr>()) {
3719  // Make sure bitwidth is allowed.
3720  if (!value.getType().isa<spirv::SPIRVType>())
3721  return emitOpError("default value bitwidth disallowed");
3722  return success();
3723  }
3724  return emitOpError(
3725  "default value can only be a bool, integer, or float scalar");
3726 }
3727 
3728 //===----------------------------------------------------------------------===//
3729 // spirv.StoreOp
3730 //===----------------------------------------------------------------------===//
3731 
3732 ParseResult spirv::StoreOp::parse(OpAsmParser &parser, OperationState &result) {
3733  // Parse the storage class specification
3734  spirv::StorageClass storageClass;
3736  auto loc = parser.getCurrentLocation();
3737  Type elementType;
3738  if (parseEnumStrAttr(storageClass, parser) ||
3739  parser.parseOperandList(operandInfo, 2) ||
3740  parseMemoryAccessAttributes(parser, result) || parser.parseColon() ||
3741  parser.parseType(elementType)) {
3742  return failure();
3743  }
3744 
3745  auto ptrType = spirv::PointerType::get(elementType, storageClass);
3746  if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
3747  result.operands)) {
3748  return failure();
3749  }
3750  return success();
3751 }
3752 
3753 void spirv::StoreOp::print(OpAsmPrinter &printer) {
3754  SmallVector<StringRef, 4> elidedAttrs;
3755  StringRef sc = stringifyStorageClass(
3756  getPtr().getType().cast<spirv::PointerType>().getStorageClass());
3757  printer << " \"" << sc << "\" " << getPtr() << ", " << getValue();
3758 
3759  printMemoryAccessAttribute(*this, printer, elidedAttrs);
3760 
3761  printer << " : " << getValue().getType();
3762  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
3763 }
3764 
3766  // SPIR-V spec : "Pointer is the pointer to store through. Its type must be an
3767  // OpTypePointer whose Type operand is the same as the type of Object."
3768  if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue())))
3769  return failure();
3770  return verifyMemoryAccessAttribute(*this);
3771 }
3772 
3773 //===----------------------------------------------------------------------===//
3774 // spirv.Unreachable
3775 //===----------------------------------------------------------------------===//
3776 
3778  auto *block = (*this)->getBlock();
3779  // Fast track: if this is in entry block, its invalid. Otherwise, if no
3780  // predecessors, it's valid.
3781  if (block->isEntryBlock())
3782  return emitOpError("cannot be used in reachable block");
3783  if (block->hasNoPredecessors())
3784  return success();
3785 
3786  // TODO: further verification needs to analyze reachability from
3787  // the entry block.
3788 
3789  return success();
3790 }
3791 
3792 //===----------------------------------------------------------------------===//
3793 // spirv.Variable
3794 //===----------------------------------------------------------------------===//
3795 
3796 ParseResult spirv::VariableOp::parse(OpAsmParser &parser,
3797  OperationState &result) {
3798  // Parse optional initializer
3799  std::optional<OpAsmParser::UnresolvedOperand> initInfo;
3800  if (succeeded(parser.parseOptionalKeyword("init"))) {
3801  initInfo = OpAsmParser::UnresolvedOperand();
3802  if (parser.parseLParen() || parser.parseOperand(*initInfo) ||
3803  parser.parseRParen())
3804  return failure();
3805  }
3806 
3807  if (parseVariableDecorations(parser, result)) {
3808  return failure();
3809  }
3810 
3811  // Parse result pointer type
3812  Type type;
3813  if (parser.parseColon())
3814  return failure();
3815  auto loc = parser.getCurrentLocation();
3816  if (parser.parseType(type))
3817  return failure();
3818 
3819  auto ptrType = type.dyn_cast<spirv::PointerType>();
3820  if (!ptrType)
3821  return parser.emitError(loc, "expected spirv.ptr type");
3822  result.addTypes(ptrType);
3823 
3824  // Resolve the initializer operand
3825  if (initInfo) {
3826  if (parser.resolveOperand(*initInfo, ptrType.getPointeeType(),
3827  result.operands))
3828  return failure();
3829  }
3830 
3831  auto attr = parser.getBuilder().getAttr<spirv::StorageClassAttr>(
3832  ptrType.getStorageClass());
3833  result.addAttribute(spirv::attributeName<spirv::StorageClass>(), attr);
3834 
3835  return success();
3836 }
3837 
3838 void spirv::VariableOp::print(OpAsmPrinter &printer) {
3839  SmallVector<StringRef, 4> elidedAttrs{
3840  spirv::attributeName<spirv::StorageClass>()};
3841  // Print optional initializer
3842  if (getNumOperands() != 0)
3843  printer << " init(" << getInitializer() << ")";
3844 
3845  printVariableDecorations(*this, printer, elidedAttrs);
3846  printer << " : " << getType();
3847 }
3848 
3850  // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
3851  // object. It cannot be Generic. It must be the same as the Storage Class
3852  // operand of the Result Type."
3853  if (getStorageClass() != spirv::StorageClass::Function) {
3854  return emitOpError(
3855  "can only be used to model function-level variables. Use "
3856  "spirv.GlobalVariable for module-level variables.");
3857  }
3858 
3859  auto pointerType = getPointer().getType().cast<spirv::PointerType>();
3860  if (getStorageClass() != pointerType.getStorageClass())
3861  return emitOpError(
3862  "storage class must match result pointer's storage class");
3863 
3864  if (getNumOperands() != 0) {
3865  // SPIR-V spec: "Initializer must be an <id> from a constant instruction or
3866  // a global (module scope) OpVariable instruction".
3867  auto *initOp = getOperand(0).getDefiningOp();
3868  if (!initOp || !isa<spirv::ConstantOp, // for normal constant
3869  spirv::ReferenceOfOp, // for spec constant
3870  spirv::AddressOfOp>(initOp))
3871  return emitOpError("initializer must be the result of a "
3872  "constant or spirv.GlobalVariable op");
3873  }
3874 
3875  // TODO: generate these strings using ODS.
3876  auto *op = getOperation();
3877  auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
3878  stringifyDecoration(spirv::Decoration::DescriptorSet));
3879  auto bindingName = llvm::convertToSnakeFromCamelCase(
3880  stringifyDecoration(spirv::Decoration::Binding));
3881  auto builtInName = llvm::convertToSnakeFromCamelCase(
3882  stringifyDecoration(spirv::Decoration::BuiltIn));
3883 
3884  for (const auto &attr : {descriptorSetName, bindingName, builtInName}) {
3885  if (op->getAttr(attr))
3886  return emitOpError("cannot have '")
3887  << attr << "' attribute (only allowed in spirv.GlobalVariable)";
3888  }
3889 
3890  return success();
3891 }
3892 
3893 //===----------------------------------------------------------------------===//
3894 // spirv.VectorShuffle
3895 //===----------------------------------------------------------------------===//
3896 
3898  VectorType resultType = getType().cast<VectorType>();
3899 
3900  size_t numResultElements = resultType.getNumElements();
3901  if (numResultElements != getComponents().size())
3902  return emitOpError("result type element count (")
3903  << numResultElements
3904  << ") mismatch with the number of component selectors ("
3905  << getComponents().size() << ")";
3906 
3907  size_t totalSrcElements =
3908  getVector1().getType().cast<VectorType>().getNumElements() +
3909  getVector2().getType().cast<VectorType>().getNumElements();
3910 
3911  for (const auto &selector : getComponents().getAsValueRange<IntegerAttr>()) {
3912  uint32_t index = selector.getZExtValue();
3913  if (index >= totalSrcElements &&
3914  index != std::numeric_limits<uint32_t>().max())
3915  return emitOpError("component selector ")
3916  << index << " out of range: expected to be in [0, "
3917  << totalSrcElements << ") or 0xffffffff";
3918  }
3919  return success();
3920 }
3921 
3922 //===----------------------------------------------------------------------===//
3923 // spirv.NV.CooperativeMatrixLoad
3924 //===----------------------------------------------------------------------===//
3925 
3926 ParseResult spirv::NVCooperativeMatrixLoadOp::parse(OpAsmParser &parser,
3927  OperationState &result) {
3929  Type strideType = parser.getBuilder().getIntegerType(32);
3930  Type columnMajorType = parser.getBuilder().getIntegerType(1);
3931  Type ptrType;
3932  Type elementType;
3933  if (parser.parseOperandList(operandInfo, 3) ||
3934  parseMemoryAccessAttributes(parser, result) || parser.parseColon() ||
3935  parser.parseType(ptrType) || parser.parseKeywordType("as", elementType)) {
3936  return failure();
3937  }
3938  if (parser.resolveOperands(operandInfo,
3939  {ptrType, strideType, columnMajorType},
3940  parser.getNameLoc(), result.operands)) {
3941  return failure();
3942  }
3943 
3944  result.addTypes(elementType);
3945  return success();
3946 }
3947 
3949  printer << " " << getPointer() << ", " << getStride() << ", "
3950  << getColumnmajor();
3951  // Print optional memory access attribute.
3952  if (auto memAccess = getMemoryAccess())
3953  printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
3954  printer << " : " << getPointer().getType() << " as " << getType();
3955 }
3956 
3958  Type coopMatrix) {
3959  Type pointeeType = pointer.cast<spirv::PointerType>().getPointeeType();
3960  if (!pointeeType.isa<spirv::ScalarType>() && !pointeeType.isa<VectorType>())
3961  return op->emitError(
3962  "Pointer must point to a scalar or vector type but provided ")
3963  << pointeeType;
3964  spirv::StorageClass storage =
3965  pointer.cast<spirv::PointerType>().getStorageClass();
3966  if (storage != spirv::StorageClass::Workgroup &&
3967  storage != spirv::StorageClass::StorageBuffer &&
3968  storage != spirv::StorageClass::PhysicalStorageBuffer)
3969  return op->emitError(
3970  "Pointer storage class must be Workgroup, StorageBuffer or "
3971  "PhysicalStorageBufferEXT but provided ")
3972  << stringifyStorageClass(storage);
3973  return success();
3974 }
3975 
3977  return verifyPointerAndCoopMatrixType(*this, getPointer().getType(),
3978  getResult().getType());
3979 }
3980 
3981 //===----------------------------------------------------------------------===//
3982 // spirv.NV.CooperativeMatrixStore
3983 //===----------------------------------------------------------------------===//
3984 
3985 ParseResult spirv::NVCooperativeMatrixStoreOp::parse(OpAsmParser &parser,
3986  OperationState &result) {
3988  Type strideType = parser.getBuilder().getIntegerType(32);
3989  Type columnMajorType = parser.getBuilder().getIntegerType(1);
3990  Type ptrType;
3991  Type elementType;
3992  if (parser.parseOperandList(operandInfo, 4) ||
3993  parseMemoryAccessAttributes(parser, result) || parser.parseColon() ||
3994  parser.parseType(ptrType) || parser.parseComma() ||
3995  parser.parseType(elementType)) {
3996  return failure();
3997  }
3998  if (parser.resolveOperands(
3999  operandInfo, {ptrType, elementType, strideType, columnMajorType},
4000  parser.getNameLoc(), result.operands)) {
4001  return failure();
4002  }
4003 
4004  return success();
4005 }
4006 
4008  printer << " " << getPointer() << ", " << getObject() << ", " << getStride()
4009  << ", " << getColumnmajor();
4010  // Print optional memory access attribute.
4011  if (auto memAccess = getMemoryAccess())
4012  printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
4013  printer << " : " << getPointer().getType() << ", " << getOperand(1).getType();
4014 }
4015 
4017  return verifyPointerAndCoopMatrixType(*this, getPointer().getType(),
4018  getObject().getType());
4019 }
4020 
4021 //===----------------------------------------------------------------------===//
4022 // spirv.NV.CooperativeMatrixMulAdd
4023 //===----------------------------------------------------------------------===//
4024 
4025 static LogicalResult
4026 verifyCoopMatrixMulAdd(spirv::NVCooperativeMatrixMulAddOp op) {
4027  if (op.getC().getType() != op.getResult().getType())
4028  return op.emitOpError("result and third operand must have the same type");
4029  auto typeA = op.getA().getType().cast<spirv::CooperativeMatrixNVType>();
4030  auto typeB = op.getB().getType().cast<spirv::CooperativeMatrixNVType>();
4031  auto typeC = op.getC().getType().cast<spirv::CooperativeMatrixNVType>();
4032  auto typeR = op.getResult().getType().cast<spirv::CooperativeMatrixNVType>();
4033  if (typeA.getRows() != typeR.getRows() ||
4034  typeA.getColumns() != typeB.getRows() ||
4035  typeB.getColumns() != typeR.getColumns())
4036  return op.emitOpError("matrix size must match");
4037  if (typeR.getScope() != typeA.getScope() ||
4038  typeR.getScope() != typeB.getScope() ||
4039  typeR.getScope() != typeC.getScope())
4040  return op.emitOpError("matrix scope must match");
4041  auto elementTypeA = typeA.getElementType();
4042  auto elementTypeB = typeB.getElementType();
4043  if (isa<IntegerType>(elementTypeA) && isa<IntegerType>(elementTypeB)) {
4044  if (elementTypeA.cast<IntegerType>().getWidth() !=
4045  elementTypeB.cast<IntegerType>().getWidth())
4046  return op.emitOpError(
4047  "matrix A and B integer element types must be the same bit width");
4048  } else if (elementTypeA != elementTypeB) {
4049  return op.emitOpError(
4050  "matrix A and B non-integer element types must match");
4051  }
4052  if (typeR.getElementType() != typeC.getElementType())
4053  return op.emitOpError("matrix accumulator element type must match");
4054  return success();
4055 }
4056 
4058  return verifyCoopMatrixMulAdd(*this);
4059 }
4060 
4061 static LogicalResult
4063  Type pointeeType = pointer.cast<spirv::PointerType>().getPointeeType();
4064  if (!pointeeType.isa<spirv::ScalarType>() && !pointeeType.isa<VectorType>())
4065  return op->emitError(
4066  "Pointer must point to a scalar or vector type but provided ")
4067  << pointeeType;
4068  spirv::StorageClass storage =
4069  pointer.cast<spirv::PointerType>().getStorageClass();
4070  if (storage != spirv::StorageClass::Workgroup &&
4071  storage != spirv::StorageClass::CrossWorkgroup &&
4072  storage != spirv::StorageClass::UniformConstant &&
4073  storage != spirv::StorageClass::Generic)
4074  return op->emitError("Pointer storage class must be Workgroup or "
4075  "CrossWorkgroup but provided ")
4076  << stringifyStorageClass(storage);
4077  return success();
4078 }
4079 
4080 //===----------------------------------------------------------------------===//
4081 // spirv.INTEL.JointMatrixLoad
4082 //===----------------------------------------------------------------------===//
4083 
4085  return verifyPointerAndJointMatrixType(*this, getPointer().getType(),
4086  getResult().getType());
4087 }
4088 
4089 //===----------------------------------------------------------------------===//
4090 // spirv.INTEL.JointMatrixStore
4091 //===----------------------------------------------------------------------===//
4092 
4094  return verifyPointerAndJointMatrixType(*this, getPointer().getType(),
4095  getObject().getType());
4096 }
4097 
4098 //===----------------------------------------------------------------------===//
4099 // spirv.INTEL.JointMatrixMad
4100 //===----------------------------------------------------------------------===//
4101 
4102 static LogicalResult verifyJointMatrixMad(spirv::INTELJointMatrixMadOp op) {
4103  if (op.getC().getType() != op.getResult().getType())
4104  return op.emitOpError("result and third operand must have the same type");
4105  auto typeA = op.getA().getType().cast<spirv::JointMatrixINTELType>();
4106  auto typeB = op.getB().getType().cast<spirv::JointMatrixINTELType>();
4107  auto typeC = op.getC().getType().cast<spirv::JointMatrixINTELType>();
4108  auto typeR = op.getResult().getType().cast<spirv::JointMatrixINTELType>();
4109  if (typeA.getRows() != typeR.getRows() ||
4110  typeA.getColumns() != typeB.getRows() ||
4111  typeB.getColumns() != typeR.getColumns())
4112  return op.emitOpError("matrix size must match");
4113  if (typeR.getScope() != typeA.getScope() ||
4114  typeR.getScope() != typeB.getScope() ||
4115  typeR.getScope() != typeC.getScope())
4116  return op.emitOpError("matrix scope must match");
4117  if (typeA.getElementType() != typeB.getElementType() ||
4118  typeR.getElementType() != typeC.getElementType())
4119  return op.emitOpError("matrix element type must match");
4120  return success();
4121 }
4122 
4124  return verifyJointMatrixMad(*this);
4125 }
4126 
4127 //===----------------------------------------------------------------------===//
4128 // spirv.MatrixTimesScalar
4129 //===----------------------------------------------------------------------===//
4130 
4132  if (auto inputCoopmat =
4133  getMatrix().getType().dyn_cast<spirv::CooperativeMatrixNVType>()) {
4134  if (inputCoopmat.getElementType() != getScalar().getType())
4135  return emitError("input matrix components' type and scaling value must "
4136  "have the same type");
4137  return success();
4138  }
4139 
4140  // Check that the scalar type is the same as the matrix element type.
4141  auto inputMatrix = getMatrix().getType().cast<spirv::MatrixType>();
4142  if (getScalar().getType() != inputMatrix.getElementType())
4143  return emitError("input matrix components' type and scaling value must "
4144  "have the same type");
4145 
4146  return success();
4147 }
4148 
4149 //===----------------------------------------------------------------------===//
4150 // spirv.CopyMemory
4151 //===----------------------------------------------------------------------===//
4152 
4154  printer << ' ';
4155 
4156  StringRef targetStorageClass = stringifyStorageClass(
4157  getTarget().getType().cast<spirv::PointerType>().getStorageClass());
4158  printer << " \"" << targetStorageClass << "\" " << getTarget() << ", ";
4159 
4160  StringRef sourceStorageClass = stringifyStorageClass(
4161  getSource().getType().cast<spirv::PointerType>().getStorageClass());
4162  printer << " \"" << sourceStorageClass << "\" " << getSource();
4163 
4164  SmallVector<StringRef, 4> elidedAttrs;
4165  printMemoryAccessAttribute(*this, printer, elidedAttrs);
4166  printSourceMemoryAccessAttribute(*this, printer, elidedAttrs,
4167  getSourceMemoryAccess(),
4168  getSourceAlignment());
4169 
4170  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
4171 
4172  Type pointeeType =
4173  getTarget().getType().cast<spirv::PointerType>().getPointeeType();
4174  printer << " : " << pointeeType;
4175 }
4176 
4177 ParseResult spirv::CopyMemoryOp::parse(OpAsmParser &parser,
4178  OperationState &result) {
4179  spirv::StorageClass targetStorageClass;
4180  OpAsmParser::UnresolvedOperand targetPtrInfo;
4181 
4182  spirv::StorageClass sourceStorageClass;
4183  OpAsmParser::UnresolvedOperand sourcePtrInfo;
4184 
4185  Type elementType;
4186 
4187  if (parseEnumStrAttr(targetStorageClass, parser) ||
4188  parser.parseOperand(targetPtrInfo) || parser.parseComma() ||
4189  parseEnumStrAttr(sourceStorageClass, parser) ||
4190  parser.parseOperand(sourcePtrInfo) ||
4191  parseMemoryAccessAttributes(parser, result)) {
4192  return failure();
4193  }
4194 
4195  if (!parser.parseOptionalComma()) {
4196  // Parse 2nd memory access attributes.
4197  if (parseSourceMemoryAccessAttributes(parser, result)) {
4198  return failure();
4199  }
4200  }
4201 
4202  if (parser.parseColon() || parser.parseType(elementType))
4203  return failure();
4204 
4205  if (parser.parseOptionalAttrDict(result.attributes))
4206  return failure();
4207 
4208  auto targetPtrType = spirv::PointerType::get(elementType, targetStorageClass);
4209  auto sourcePtrType = spirv::PointerType::get(elementType, sourceStorageClass);
4210 
4211  if (parser.resolveOperand(targetPtrInfo, targetPtrType, result.operands) ||
4212  parser.resolveOperand(sourcePtrInfo, sourcePtrType, result.operands)) {
4213  return failure();
4214  }
4215 
4216  return success();
4217 }
4218 
4220  Type targetType =
4221  getTarget().getType().cast<spirv::PointerType>().getPointeeType();
4222 
4223  Type sourceType =
4224  getSource().getType().cast<spirv::PointerType>().getPointeeType();
4225 
4226  if (targetType != sourceType)
4227  return emitOpError("both operands must be pointers to the same type");
4228 
4229  if (failed(verifyMemoryAccessAttribute(*this)))
4230  return failure();
4231 
4232  // TODO - According to the spec:
4233  //
4234  // If two masks are present, the first applies to Target and cannot include
4235  // MakePointerVisible, and the second applies to Source and cannot include
4236  // MakePointerAvailable.
4237  //
4238  // Add such verification here.
4239 
4240  return verifySourceMemoryAccessAttribute(*this);
4241 }
4242 
4243 //===----------------------------------------------------------------------===//
4244 // spirv.Transpose
4245 //===----------------------------------------------------------------------===//
4246 
4248  auto inputMatrix = getMatrix().getType().cast<spirv::MatrixType>();
4249  auto resultMatrix = getResult().getType().cast<spirv::MatrixType>();
4250 
4251  // Verify that the input and output matrices have correct shapes.
4252  if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
4253  return emitError("input matrix rows count must be equal to "
4254  "output matrix columns count");
4255 
4256  if (inputMatrix.getNumColumns() != resultMatrix.getNumRows())
4257  return emitError("input matrix columns count must be equal to "
4258  "output matrix rows count");
4259 
4260  // Verify that the input and output matrices have the same component type
4261  if (inputMatrix.getElementType() != resultMatrix.getElementType())
4262  return emitError("input and output matrices must have the same "
4263  "component type");
4264 
4265  return success();
4266 }
4267 
4268 //===----------------------------------------------------------------------===//
4269 // spirv.MatrixTimesMatrix
4270 //===----------------------------------------------------------------------===//
4271 
4273  auto leftMatrix = getLeftmatrix().getType().cast<spirv::MatrixType>();
4274  auto rightMatrix = getRightmatrix().getType().cast<spirv::MatrixType>();
4275  auto resultMatrix = getResult().getType().cast<spirv::MatrixType>();
4276 
4277  // left matrix columns' count and right matrix rows' count must be equal
4278  if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
4279  return emitError("left matrix columns' count must be equal to "
4280  "the right matrix rows' count");
4281 
4282  // right and result matrices columns' count must be the same
4283  if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns())
4284  return emitError(
4285  "right and result matrices must have equal columns' count");
4286 
4287  // right and result matrices component type must be the same
4288  if (rightMatrix.getElementType() != resultMatrix.getElementType())
4289  return emitError("right and result matrices' component type must"
4290  " be the same");
4291 
4292  // left and result matrices component type must be the same
4293  if (leftMatrix.getElementType() != resultMatrix.getElementType())
4294  return emitError("left and result matrices' component type"
4295  " must be the same");
4296 
4297  // left and result matrices rows count must be the same
4298  if (leftMatrix.getNumRows() != resultMatrix.getNumRows())
4299  return emitError("left and result matrices must have equal rows' count");
4300 
4301  return success();
4302 }
4303 
4304 //===----------------------------------------------------------------------===//
4305 // spirv.SpecConstantComposite
4306 //===----------------------------------------------------------------------===//
4307 
4308 ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser,
4309  OperationState &result) {
4310 
4311  StringAttr compositeName;
4312  if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
4313  result.attributes))
4314  return failure();
4315 
4316  if (parser.parseLParen())
4317  return failure();
4318 
4319  SmallVector<Attribute, 4> constituents;
4320 
4321  do {
4322  // The name of the constituent attribute isn't important
4323  const char *attrName = "spec_const";
4324  FlatSymbolRefAttr specConstRef;
4325  NamedAttrList attrs;
4326 
4327  if (parser.parseAttribute(specConstRef, Type(), attrName, attrs))
4328  return failure();
4329 
4330  constituents.push_back(specConstRef);
4331  } while (!parser.parseOptionalComma());
4332 
4333  if (parser.parseRParen())
4334  return failure();
4335 
4337  parser.getBuilder().getArrayAttr(constituents));
4338 
4339  Type type;
4340  if (parser.parseColonType(type))
4341  return failure();
4342 
4343  result.addAttribute(kTypeAttrName, TypeAttr::get(type));
4344 
4345  return success();
4346 }
4347 
4349  printer << " ";
4350  printer.printSymbolName(getSymName());
4351  printer << " (";
4352  auto constituents = this->getConstituents().getValue();
4353 
4354  if (!constituents.empty())
4355  llvm::interleaveComma(constituents, printer);
4356 
4357  printer << ") : " << getType();
4358 }
4359 
4361  auto cType = getType().dyn_cast<spirv::CompositeType>();
4362  auto constituents = this->getConstituents().getValue();
4363 
4364  if (!cType)
4365  return emitError("result type must be a composite type, but provided ")
4366  << getType();
4367 
4368  if (cType.isa<spirv::CooperativeMatrixNVType>())
4369  return emitError("unsupported composite type ") << cType;
4370  if (cType.isa<spirv::JointMatrixINTELType>())
4371  return emitError("unsupported composite type ") << cType;
4372  if (constituents.size() != cType.getNumElements())
4373  return emitError("has incorrect number of operands: expected ")
4374  << cType.getNumElements() << ", but provided "
4375  << constituents.size();
4376 
4377  for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
4378  auto constituent = constituents[index].cast<FlatSymbolRefAttr>();
4379 
4380  auto constituentSpecConstOp =
4381  dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom(
4382  (*this)->getParentOp(), constituent.getAttr()));
4383 
4384  if (constituentSpecConstOp.getDefaultValue().getType() !=
4385  cType.getElementType(index))
4386  return emitError("has incorrect types of operands: expected ")
4387  << cType.getElementType(index) << ", but provided "
4388  << constituentSpecConstOp.getDefaultValue().getType();
4389  }
4390 
4391  return success();
4392 }
4393 
4394 //===----------------------------------------------------------------------===//
4395 // spirv.SpecConstantOperation
4396 //===----------------------------------------------------------------------===//
4397 
4398 ParseResult spirv::SpecConstantOperationOp::parse(OpAsmParser &parser,
4399  OperationState &result) {
4400  Region *body = result.addRegion();
4401 
4402  if (parser.parseKeyword("wraps"))
4403  return failure();
4404 
4405  body->push_back(new Block);
4406  Block &block = body->back();
4407  Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
4408 
4409  if (!wrappedOp)
4410  return failure();
4411 
4412  OpBuilder builder(parser.getContext());
4413  builder.setInsertionPointToEnd(&block);
4414  builder.create<spirv::YieldOp>(wrappedOp->getLoc(), wrappedOp->getResult(0));
4415  result.location = wrappedOp->getLoc();
4416 
4417  result.addTypes(wrappedOp->getResult(0).getType());
4418 
4419  if (parser.parseOptionalAttrDict(result.attributes))
4420  return failure();
4421 
4422  return success();
4423 }
4424 
4426  printer << " wraps ";
4427  printer.printGenericOp(&getBody().front().front());
4428 }
4429 
4430 LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
4431  Block &block = getRegion().getBlocks().front();
4432 
4433  if (block.getOperations().size() != 2)
4434  return emitOpError("expected exactly 2 nested ops");
4435 
4436  Operation &enclosedOp = block.getOperations().front();
4437 
4439  return emitOpError("invalid enclosed op");
4440 
4441  for (auto operand : enclosedOp.getOperands())
4442  if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
4443  spirv::SpecConstantOperationOp>(operand.getDefiningOp()))
4444  return emitOpError(
4445  "invalid operand, must be defined by a constant operation");
4446 
4447  return success();
4448 }
4449 
4450 //===----------------------------------------------------------------------===//
4451 // spirv.GL.FrexpStruct
4452 //===----------------------------------------------------------------------===//
4453 
4455  spirv::StructType structTy =
4456  getResult().getType().dyn_cast<spirv::StructType>();
4457 
4458  if (structTy.getNumElements() != 2)
4459  return emitError("result type must be a struct type with two memebers");
4460 
4461  Type significandTy = structTy.getElementType(0);
4462  Type exponentTy = structTy.getElementType(1);
4463  VectorType exponentVecTy = exponentTy.dyn_cast<VectorType>();
4464  IntegerType exponentIntTy = exponentTy.dyn_cast<IntegerType>();
4465 
4466  Type operandTy = getOperand().getType();
4467  VectorType operandVecTy = operandTy.dyn_cast<VectorType>();
4468  FloatType operandFTy = operandTy.dyn_cast<FloatType>();
4469 
4470  if (significandTy != operandTy)
4471  return emitError("member zero of the resulting struct type must be the "
4472  "same type as the operand");
4473 
4474  if (exponentVecTy) {
4475  IntegerType componentIntTy =
4476  exponentVecTy.getElementType().dyn_cast<IntegerType>();
4477  if (!componentIntTy || componentIntTy.getWidth() != 32)
4478  return emitError("member one of the resulting struct type must"
4479  "be a scalar or vector of 32 bit integer type");
4480  } else if (!exponentIntTy || exponentIntTy.getWidth() != 32) {
4481  return emitError("member one of the resulting struct type "
4482  "must be a scalar or vector of 32 bit integer type");
4483  }
4484 
4485  // Check that the two member types have the same number of components
4486  if (operandVecTy && exponentVecTy &&
4487  (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
4488  return success();
4489 
4490  if (operandFTy && exponentIntTy)
4491  return success();
4492 
4493  return emitError("member one of the resulting struct type must have the same "
4494  "number of components as the operand type");
4495 }
4496 
4497 //===----------------------------------------------------------------------===//
4498 // spirv.GL.Ldexp
4499 //===----------------------------------------------------------------------===//
4500 
4502  Type significandType = getX().getType();
4503  Type exponentType = getExp().getType();
4504 
4505  if (significandType.isa<FloatType>() != exponentType.isa<IntegerType>())
4506  return emitOpError("operands must both be scalars or vectors");
4507 
4508  auto getNumElements = [](Type type) -> unsigned {
4509  if (auto vectorType = type.dyn_cast<VectorType>())
4510  return vectorType.getNumElements();
4511  return 1;
4512  };
4513 
4514  if (getNumElements(significandType) != getNumElements(exponentType))
4515  return emitOpError("operands must have the same number of elements");
4516 
4517  return success();
4518 }
4519 
4520 //===----------------------------------------------------------------------===//
4521 // spirv.ImageDrefGather
4522 //===----------------------------------------------------------------------===//
4523 
4525  VectorType resultType = getResult().getType().cast<VectorType>();
4526  auto sampledImageType =
4527  getSampledimage().getType().cast<spirv::SampledImageType>();
4528  auto imageType = sampledImageType.getImageType().cast<spirv::ImageType>();
4529 
4530  if (resultType.getNumElements() != 4)
4531  return emitOpError("result type must be a vector of four components");
4532 
4533  Type elementType = resultType.getElementType();
4534  Type sampledElementType = imageType.getElementType();
4535  if (!sampledElementType.isa<NoneType>() && elementType != sampledElementType)
4536  return emitOpError(
4537  "the component type of result must be the same as sampled type of the "
4538  "underlying image type");
4539 
4540  spirv::Dim imageDim = imageType.getDim();
4541  spirv::ImageSamplingInfo imageMS = imageType.getSamplingInfo();
4542 
4543  if (imageDim != spirv::Dim::Dim2D && imageDim != spirv::Dim::Cube &&
4544  imageDim != spirv::Dim::Rect)
4545  return emitOpError(
4546  "the Dim operand of the underlying image type must be 2D, Cube, or "
4547  "Rect");
4548 
4549  if (imageMS != spirv::ImageSamplingInfo::SingleSampled)
4550  return emitOpError("the MS operand of the underlying image type must be 0");
4551 
4552  spirv::ImageOperandsAttr attr = getImageoperandsAttr();
4553  auto operandArguments = getOperandArguments();
4554 
4555  return verifyImageOperands(*this, attr, operandArguments);
4556 }
4557 
4558 //===----------------------------------------------------------------------===//
4559 // spirv.ShiftLeftLogicalOp
4560 //===----------------------------------------------------------------------===//
4561 
4563  return verifyShiftOp(*this);
4564 }
4565 
4566 //===----------------------------------------------------------------------===//
4567 // spirv.ShiftRightArithmeticOp
4568 //===----------------------------------------------------------------------===//
4569 
4571  return verifyShiftOp(*this);
4572 }
4573 
4574 //===----------------------------------------------------------------------===//
4575 // spirv.ShiftRightLogicalOp
4576 //===----------------------------------------------------------------------===//
4577 
4579  return verifyShiftOp(*this);
4580 }
4581 
4582 //===----------------------------------------------------------------------===//
4583 // spirv.ImageQuerySize
4584 //===----------------------------------------------------------------------===//
4585 
4587  spirv::ImageType imageType = getImage().getType().cast<spirv::ImageType>();
4588  Type resultType = getResult().getType();
4589 
4590  spirv::Dim dim = imageType.getDim();
4591  spirv::ImageSamplingInfo samplingInfo = imageType.getSamplingInfo();
4592  spirv::ImageSamplerUseInfo samplerInfo = imageType.getSamplerUseInfo();
4593  switch (dim) {
4594  case spirv::Dim::Dim1D:
4595  case spirv::Dim::Dim2D:
4596  case spirv::Dim::Dim3D:
4597  case spirv::Dim::Cube:
4598  if (samplingInfo != spirv::ImageSamplingInfo::MultiSampled &&
4599  samplerInfo != spirv::ImageSamplerUseInfo::SamplerUnknown &&
4600  samplerInfo != spirv::ImageSamplerUseInfo::NoSampler)
4601  return emitError(
4602  "if Dim is 1D, 2D, 3D, or Cube, "
4603  "it must also have either an MS of 1 or a Sampled of 0 or 2");
4604  break;
4605  case spirv::Dim::Buffer:
4606  case spirv::Dim::Rect:
4607  break;
4608  default:
4609  return emitError("the Dim operand of the image type must "
4610  "be 1D, 2D, 3D, Buffer, Cube, or Rect");
4611  }
4612 
4613  unsigned componentNumber = 0;
4614  switch (dim) {
4615  case spirv::Dim::Dim1D:
4616  case spirv::Dim::Buffer:
4617  componentNumber = 1;
4618  break;
4619  case spirv::Dim::Dim2D:
4620  case spirv::Dim::Cube:
4621  case spirv::Dim::Rect:
4622  componentNumber = 2;
4623  break;
4624  case spirv::Dim::Dim3D:
4625  componentNumber = 3;
4626  break;
4627  default:
4628  break;
4629  }
4630 
4631  if (imageType.getArrayedInfo() == spirv::ImageArrayedInfo::Arrayed)
4632  componentNumber += 1;
4633 
4634  unsigned resultComponentNumber = 1;
4635  if (auto resultVectorType = resultType.dyn_cast<VectorType>())
4636  resultComponentNumber = resultVectorType.getNumElements();
4637 
4638  if (componentNumber != resultComponentNumber)
4639  return emitError("expected the result to have ")
4640  << componentNumber << " component(s), but found "
4641  << resultComponentNumber << " component(s)";
4642 
4643  return success();
4644 }
4645 
4646 static ParseResult parsePtrAccessChainOpImpl(StringRef opName,
4647  OpAsmParser &parser,
4648  OperationState &state) {
4651  Type type;
4652  auto loc = parser.getCurrentLocation();
4653  SmallVector<Type, 4> indicesTypes;
4654 
4655  if (parser.parseOperand(ptrInfo) ||
4656  parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
4657  parser.parseColonType(type) ||
4658  parser.resolveOperand(ptrInfo, type, state.operands))
4659  return failure();
4660 
4661  // Check that the provided indices list is not empty before parsing their
4662  // type list.
4663  if (indicesInfo.empty())
4664  return emitError(state.location) << opName << " expected element";
4665 
4666  if (parser.parseComma() || parser.parseTypeList(indicesTypes))
4667  return failure();
4668 
4669  // Check that the indices types list is not empty and that it has a one-to-one
4670  // mapping to the provided indices.
4671  if (indicesTypes.size() != indicesInfo.size())
4672  return emitError(state.location)
4673  << opName
4674  << " indices types' count must be equal to indices info count";
4675 
4676  if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands))
4677  return failure();
4678 
4679  auto resultType = getElementPtrType(
4680  type, llvm::ArrayRef(state.operands).drop_front(2), state.location);
4681  if (!resultType)
4682  return failure();
4683 
4684  state.addTypes(resultType);
4685  return success();
4686 }
4687 
4688 template <typename Op>
4689 static auto concatElemAndIndices(Op op) {
4690  SmallVector<Value> ret(op.getIndices().size() + 1);
4691  ret[0] = op.getElement();
4692  llvm::copy(op.getIndices(), ret.begin() + 1);
4693  return ret;
4694 }
4695 
4696 //===----------------------------------------------------------------------===//
4697 // spirv.InBoundsPtrAccessChainOp
4698 //===----------------------------------------------------------------------===//
4699 
4700 void spirv::InBoundsPtrAccessChainOp::build(OpBuilder &builder,
4701  OperationState &state,
4702  Value basePtr, Value element,
4703  ValueRange indices) {
4704  auto type = getElementPtrType(basePtr.getType(), indices, state.location);
4705  assert(type && "Unable to deduce return type based on basePtr and indices");
4706  build(builder, state, type, basePtr, element, indices);
4707 }
4708 
4709 ParseResult spirv::InBoundsPtrAccessChainOp::parse(OpAsmParser &parser,
4710  OperationState &result) {
4712  spirv::InBoundsPtrAccessChainOp::getOperationName(), parser, result);
4713 }
4714 
4716  printAccessChain(*this, concatElemAndIndices(*this), printer);
4717 }
4718 
4720  return verifyAccessChain(*this, getIndices());
4721 }
4722 
4723 //===----------------------------------------------------------------------===//
4724 // spirv.PtrAccessChainOp
4725 //===----------------------------------------------------------------------===//
4726 
4727 void spirv::PtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
4728  Value basePtr, Value element,
4729  ValueRange indices) {
4730  auto type = getElementPtrType(basePtr.getType(), indices, state.location);
4731  assert(type && "Unable to deduce return type based on basePtr and indices");
4732  build(builder, state, type, basePtr, element, indices);
4733 }
4734 
4735 ParseResult spirv::PtrAccessChainOp::parse(OpAsmParser &parser,
4736  OperationState &result) {
4737  return parsePtrAccessChainOpImpl(spirv::PtrAccessChainOp::getOperationName(),
4738  parser, result);
4739 }
4740 
4742  printAccessChain(*this, concatElemAndIndices(*this), printer);
4743 }
4744 
4746  return verifyAccessChain(*this, getIndices());
4747 }
4748 
4749 //===----------------------------------------------------------------------===//
4750 // spirv.VectorTimesScalarOp
4751 //===----------------------------------------------------------------------===//
4752 
4754  if (getVector().getType() != getType())
4755  return emitOpError("vector operand and result type mismatch");
4756  auto scalarType = getType().cast<VectorType>().getElementType();
4757  if (getScalar().getType() != scalarType)
4758  return emitOpError("scalar operand and result element type match");
4759  return success();
4760 }
4761 
4762 //===----------------------------------------------------------------------===//
4763 // Group ops
4764 //===----------------------------------------------------------------------===//
4765 
4766 template <typename Op>
4768  spirv::Scope scope = op.getExecutionScope();
4769  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
4770  return op.emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
4771 
4772  return success();
4773 }
4774 
4776 
4778 
4780 
4782 
4784 
4786 
4788 
4790 
4792 
4794 
4795 //===----------------------------------------------------------------------===//
4796 // Integer Dot Product ops
4797 //===----------------------------------------------------------------------===//
4798 
4800  assert(llvm::is_contained({2u, 3u}, op->getNumOperands()) &&
4801  "Not an integer dot product op?");
4802  assert(op->getNumResults() == 1 && "Expected a single result");
4803 
4804  Type factorTy = op->getOperand(0).getType();
4805  if (op->getOperand(1).getType() != factorTy)
4806  return op->emitOpError("requires the same type for both vector operands");
4807 
4808  unsigned expectedNumAttrs = 0;
4809  if (auto intTy = factorTy.dyn_cast<IntegerType>()) {
4810  ++expectedNumAttrs;
4811  auto packedVectorFormat =
4813  .dyn_cast_or_null<spirv::PackedVectorFormatAttr>();
4814  if (!packedVectorFormat)
4815  return op->emitOpError("requires Packed Vector Format attribute for "
4816  "integer vector operands");
4817 
4818  assert(packedVectorFormat.getValue() ==
4819  spirv::PackedVectorFormat::PackedVectorFormat4x8Bit &&
4820  "Unknown Packed Vector Format");
4821  if (intTy.getWidth() != 32)
4822  return op->emitOpError(
4823  llvm::formatv("with specified Packed Vector Format ({0}) requires "
4824  "integer vector operands to be 32-bits wide",
4825  packedVectorFormat.getValue()));
4826  } else {
4828  return op->emitOpError(llvm::formatv(
4829  "with invalid format attribute for vector operands of type '{0}'",
4830  factorTy));
4831  }
4832 
4833  if (op->getAttrs().size() > expectedNumAttrs)
4834  return op->emitError(
4835  "op only supports the 'format' #spirv.packed_vector_format attribute");
4836 
4837  Type resultTy = op->getResultTypes().front();
4838  bool hasAccumulator = op->getNumOperands() == 3;
4839  if (hasAccumulator && op->getOperand(2).getType() != resultTy)
4840  return op->emitOpError(
4841  "requires the same accumulator operand and result types");
4842 
4843  unsigned factorBitWidth = getBitWidth(factorTy);
4844  unsigned resultBitWidth = getBitWidth(resultTy);
4845  if (factorBitWidth > resultBitWidth)
4846  return op->emitOpError(
4847  llvm::formatv("result type has insufficient bit-width ({0} bits) "
4848  "for the specified vector operand type ({1} bits)",
4849  resultBitWidth, factorBitWidth));
4850 
4851  return success();
4852 }
4853 
4854 static std::optional<spirv::Version> getIntegerDotProductMinVersion() {
4855  return spirv::Version::V_1_0; // Available in SPIR-V >= 1.0.
4856 }
4857 
4858 static std::optional<spirv::Version> getIntegerDotProductMaxVersion() {
4859  return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6.
4860 }
4861 
4864  // Requires the SPV_KHR_integer_dot_product extension, specified either
4865  // explicitly or implied by target env's SPIR-V version >= 1.6.
4866  static const auto extension = spirv::Extension::SPV_KHR_integer_dot_product;
4867  return {extension};
4868 }
4869 
4872  // Requires the DotProduct capability and capabilities that depend on
4873  // exact op types.
4874  static const auto dotProductCap = spirv::Capability::DotProduct;
4875  static const auto dotProductInput4x8BitPackedCap =
4876  spirv::Capability::DotProductInput4x8BitPacked;
4877  static const auto dotProductInput4x8BitCap =
4878  spirv::Capability::DotProductInput4x8Bit;
4879  static const auto dotProductInputAllCap =
4880  spirv::Capability::DotProductInputAll;
4881 
4882  SmallVector<ArrayRef<spirv::Capability>, 1> capabilities = {dotProductCap};
4883 
4884  Type factorTy = op->getOperand(0).getType();
4885  if (auto intTy = factorTy.dyn_cast<IntegerType>()) {
4886  auto formatAttr = op->getAttr(kPackedVectorFormatAttrName)
4887  .cast<spirv::PackedVectorFormatAttr>();
4888  if (formatAttr.getValue() ==
4889  spirv::PackedVectorFormat::PackedVectorFormat4x8Bit)
4890  capabilities.push_back(dotProductInput4x8BitPackedCap);
4891 
4892  return capabilities;
4893  }
4894 
4895  auto vecTy = factorTy.cast<VectorType>();
4896  if (vecTy.getElementTypeBitWidth() == 8) {
4897  capabilities.push_back(dotProductInput4x8BitCap);
4898  return capabilities;
4899  }
4900 
4901  capabilities.push_back(dotProductInputAllCap);
4902  return capabilities;
4903 }
4904 
4905 #define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpName) \
4906  LogicalResult OpName::verify() { return verifyIntegerDotProduct(*this); } \
4907  SmallVector<ArrayRef<spirv::Extension>, 1> OpName::getExtensions() { \
4908  return getIntegerDotProductExtensions(); \
4909  } \
4910  SmallVector<ArrayRef<spirv::Capability>, 1> OpName::getCapabilities() { \
4911  return getIntegerDotProductCapabilities(*this); \
4912  } \
4913  std::optional<spirv::Version> OpName::getMinVersion() { \
4914  return getIntegerDotProductMinVersion(); \
4915  } \
4916  std::optional<spirv::Version> OpName::getMaxVersion() { \
4917  return getIntegerDotProductMaxVersion(); \
4918  }
4919 
4920 SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SDotOp)
4921 SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SUDotOp)
4922 SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::UDotOp)
4923 SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SDotAccSatOp)
4924 SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SUDotAccSatOp)
4925 SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::UDotAccSatOp)
4926 
4927 #undef SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP
4928 
4929 // TableGen'erated operation interfaces for querying versions, extensions, and
4930 // capabilities.
4931 #include "mlir/Dialect/SPIRV/IR/SPIRVAvailability.cpp.inc"
4932 
4933 // TableGen'erated operation definitions.
4934 #define GET_OP_CLASSES
4935 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"
4936 
4937 namespace mlir {
4938 namespace spirv {
4939 // TableGen'erated operation availability interface implementations.
4940 #include "mlir/Dialect/SPIRV/IR/SPIRVOpAvailabilityImpl.inc"
4941 } // namespace spirv
4942 } // namespace mlir
static std::string bindingName()
Returns the string name of the Binding decoration.
static std::string descriptorSetName()
Returns the string name of the DescriptorSet decoration.
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
@ None
Operation::operand_range getIndices(Operation *op)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser, OperationState &result)
Definition: SPIRVOps.cpp:785
static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:903
constexpr char kValuesAttrName[]
Definition: SPIRVOps.cpp:65
static LogicalResult verifyPointerAndJointMatrixType(Operation *op, Type pointer, Type jointMatrix)
Definition: SPIRVOps.cpp:4062
static ParseResult parseVariableDecorations(OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:611
constexpr char kClusterSize[]
Definition: SPIRVOps.cpp:45
static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp)
Definition: SPIRVOps.cpp:958
constexpr char kValueAttrName[]
Definition: SPIRVOps.cpp:64
static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, Type opType)
Definition: SPIRVOps.cpp:1965
constexpr char kEqualSemanticsAttrName[]
Definition: SPIRVOps.cpp:48
static LogicalResult verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics)
Definition: SPIRVOps.cpp:560
static LogicalResult verifyCastOp(Operation *op, bool requireSameBitWidth=true, bool skipBitWidthCheck=false)
Definition: SPIRVOps.cpp:422
static ParseResult parseAtomicUpdateOp(OpAsmParser &parser, OperationState &state, bool hasValue)
Definition: SPIRVOps.cpp:823
constexpr char kExecutionScopeAttrName[]
Definition: SPIRVOps.cpp:49
static void printVariableDecorations(Operation *op, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs)
Definition: SPIRVOps.cpp:648
static unsigned getBitWidth(Type type)
Definition: SPIRVOps.cpp:676
constexpr char kTypeAttrName[]
Definition: SPIRVOps.cpp:62
static bool isMergeBlock(Block &block)
Returns true if the given block only contains one spirv.mlir.merge op.
Definition: SPIRVOps.cpp:765
static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, OperationState &result)
Definition: SPIRVOps.cpp:72
constexpr char kUnequalSemanticsAttrName[]
Definition: SPIRVOps.cpp:63
static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices)
Definition: SPIRVOps.cpp:1130
static std::optional< spirv::Version > getIntegerDotProductMinVersion()
Definition: SPIRVOps.cpp:4854
constexpr char kIndicesAttrName[]
Definition: SPIRVOps.cpp:52
static void printAtomicCompareExchangeImpl(T atomOp, OpAsmPrinter &printer)
Definition: SPIRVOps.cpp:1178
static ParseResult parseImageOperands(OpAsmParser &parser, spirv::ImageOperandsAttr &attr)
Definition: SPIRVOps.cpp:371
static StringRef stringifyTypeName()
static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op)
Definition: SPIRVOps.cpp:771
static LogicalResult verifyGroupNonUniformShuffleOp(OpTy op)
Definition: SPIRVOps.cpp:2757
static LogicalResult verifyAtomicUpdateOp(Operation *op)
Definition: SPIRVOps.cpp:879
constexpr char kBranchWeightAttrName[]
Definition: SPIRVOps.cpp:43
constexpr char kCompositeSpecConstituentsName[]
Definition: SPIRVOps.cpp:66
static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:285
constexpr char kInterfaceAttrName[]
Definition: SPIRVOps.cpp:54
static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr, Value val)
Definition: SPIRVOps.cpp:584
constexpr char kSourceAlignmentAttrName[]
Definition: SPIRVOps.cpp:59
constexpr char kCallee[]
Definition: SPIRVOps.cpp:44
constexpr char kMemoryScopeAttrName[]
Definition: SPIRVOps.cpp:56
static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer, Type coopMatrix)
Definition: SPIRVOps.cpp:3957
static bool isDirectInModuleLikeOp(Operation *op)
Returns true if the given op is a module-like op that maintains a symbol table.
Definition: SPIRVOps.cpp:134
static LogicalResult verifyShiftOp(Operation *op)
Definition: SPIRVOps.cpp:998
constexpr char kGroupOperationAttrName[]
Definition: SPIRVOps.cpp:51
static Type getUnaryOpResultType(Type operandType)
Result of a logical op must be a scalar or vector of boolean type.
Definition: SPIRVOps.cpp:990
static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser, OperationState &state)
Parses optional memory access attributes attached to a memory access operand/pointer.
Definition: SPIRVOps.cpp:254
static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc)
Definition: SPIRVOps.cpp:1012
static LogicalResult verifyJointMatrixMad(spirv::INTELJointMatrixMadOp op)
Definition: SPIRVOps.cpp:4102
#define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpName)
Definition: SPIRVOps.cpp:4905
constexpr char kMemoryAccessAttrName[]
Definition: SPIRVOps.cpp:55
constexpr char kControl[]
Definition: SPIRVOps.cpp:46
static void printMemoryAccessAttribute(MemoryOpTy memoryOp, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs, std::optional< spirv::MemoryAccess > memoryAccessAtrrValue=std::nullopt, std::optional< uint32_t > alignmentAttrValue=std::nullopt)
Definition: SPIRVOps.cpp:313
constexpr char kDefaultValueAttrName[]
Definition: SPIRVOps.cpp:47
static LogicalResult verifyGroupOp(Op op)
Definition: SPIRVOps.cpp:4767
static LogicalResult extractValueFromConstOp(Operation *op, int32_t &value)
Definition: SPIRVOps.cpp:138
static void printAtomicUpdateOp(Operation *op, OpAsmPrinter &printer)
Definition: SPIRVOps.cpp:854
static LogicalResult verifyIntegerDotProduct(Operation *op)
Definition: SPIRVOps.cpp:4799
StringRef stringifyTypeName< FloatType >()
Definition: SPIRVOps.cpp:873
static auto concatElemAndIndices(Op op)
Definition: SPIRVOps.cpp:4689
static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op, Value ptr, Value val)
Definition: SPIRVOps.cpp:599
static void printGroupNonUniformArithmeticOp(Operation *groupOp, OpAsmPrinter &printer)
Definition: SPIRVOps.cpp:939
static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer)
Definition: SPIRVOps.cpp:1120
static bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock)
Returns true if the given srcBlock contains only one spirv.Branch to the given dstBlock.
Definition: SPIRVOps.cpp:3182
static bool isNestedInFunctionOpInterface(Operation *op)
Returns true if the given op is a function-like op or nested in a function-like op without a module-l...
Definition: SPIRVOps.cpp:122
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:698
static void printOneResultOp(Operation *op, OpAsmPrinter &p)
Definition: SPIRVOps.cpp:101
static SmallVector< ArrayRef< spirv::Extension >, 1 > getIntegerDotProductExtensions()
Definition: SPIRVOps.cpp:4863
static void printImageOperands(OpAsmPrinter &printer, Operation *imageOp, spirv::ImageOperandsAttr attr)
Definition: SPIRVOps.cpp:386
static ParseResult parsePtrAccessChainOpImpl(StringRef opName, OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:4646
constexpr char kSourceMemoryAccessAttrName[]
Definition: SPIRVOps.cpp:60
constexpr char kSpecIdAttrName[]
Definition: SPIRVOps.cpp:61
constexpr char kAlignmentAttrName[]
Definition: SPIRVOps.cpp:42
static SmallVector< ArrayRef< spirv::Capability >, 1 > getIntegerDotProductCapabilities(Operation *op)
Definition: SPIRVOps.cpp:4871
static std::optional< spirv::Version > getIntegerDotProductMaxVersion()
Definition: SPIRVOps.cpp:4858
static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp)
Definition: SPIRVOps.cpp:520
static ArrayAttr getStrArrayAttrForEnumList(Builder &builder, ArrayRef< Ty > enumValues, function_ref< StringRef(Ty)> stringifyFn)
Definition: SPIRVOps.cpp:159
constexpr char kPackedVectorFormatAttrName[]
Definition: SPIRVOps.cpp:57
static ParseResult parseAtomicCompareExchangeImpl(OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:1185
StringRef stringifyTypeName< IntegerType >()
Definition: SPIRVOps.cpp:868
static void printSourceMemoryAccessAttribute(MemoryOpTy memoryOp, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs, std::optional< spirv::MemoryAccess > memoryAccessAtrrValue=std::nullopt, std::optional< uint32_t > alignmentAttrValue=std::nullopt)
Definition: SPIRVOps.cpp:343
static LogicalResult verifyImageOperands(Op imageOp, spirv::ImageOperandsAttr attr, Operation::operand_range operands)
Definition: SPIRVOps.cpp:395
static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp)
Definition: SPIRVOps.cpp:476
constexpr char kSemanticsAttrName[]
Definition: SPIRVOps.cpp:58
static LogicalResult verifyAtomicCompareExchangeImpl(T atomOp)
Definition: SPIRVOps.cpp:1218
static ParseResult parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next string attribute in parser as an enumerant of the given EnumClass.
Definition: SPIRVOps.cpp:176
static LogicalResult verifyCoopMatrixMulAdd(spirv::NVCooperativeMatrixMulAddOp op)
Definition: SPIRVOps.cpp:4026
constexpr char kInitializerAttrName[]
Definition: SPIRVOps.cpp:53
static ParseResult parseControlAttribute(OpAsmParser &parser, OperationState &state, StringRef attrName=spirv::attributeName< EnumClass >())
Parses Function, Selection and Loop control attributes.
Definition: SPIRVOps.cpp:231
constexpr char kFnNameAttrName[]
Definition: SPIRVOps.cpp:50
static void printArithmeticExtendedBinaryOp(Operation *op, OpAsmPrinter &printer)
Definition: SPIRVOps.cpp:809
static int64_t getNumElements(ShapedType type)
Definition: TensorOps.cpp:1181
static bool isZero(OpFoldResult v)
Definition: Tiling.cpp:47
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseOptionalSymbolName(StringAttr &result)=0
Parse an optional @-identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:67
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
ParseResult addTypesToList(ArrayRef< Type > types, SmallVectorImpl< Type > &result)
Add the specified types to the end of the specified type list and return success.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
U dyn_cast_or_null() const
Definition: Attributes.h:171
U dyn_cast() const
Definition: Attributes.h:166
U cast() const
Definition: Attributes.h:176
bool isa() const
Casting utility functions.
Definition: Attributes.h:156
Block represents an ordered list of Operations.
Definition: Block.h:30
unsigned getNumSuccessors()
Definition: Block.cpp:238
bool empty()
Definition: Block.h:137
Operation & back()
Definition: Block.h:141
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:232
OpListType & getOperations()
Definition: Block.h:126
Operation & front()
Definition: Block.h:142
iterator end()
Definition: Block.h:133
iterator begin()
Definition: Block.h:132
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:202
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:165
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:224
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:269
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:247
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:93
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:84
NoneType getNoneType()
Definition: Builders.cpp:101
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:113
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:255
MLIRContext * getContext() const
Definition: Builders.h:55
IntegerType getI1Type()
Definition: Builders.cpp:70
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:259
ArrayAttr getStrArrayAttr(ArrayRef< StringRef > values)
Definition: Builders.cpp:300
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition: Builders.h:99
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:41
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:308
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
NamedAttrList is an array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual ParseResult parseSuccessorAndUseList(Block *&dest, SmallVectorImpl< Value > &operands)=0
Parse a single operation successor and its operand list.
virtual OptionalParseResult parseOptionalRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual Operation * parseGenericOperation(Block *insertBlock, Block::iterator insertPt)=0
Parse an operation in its generic form.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printSuccessorAndUseList(Block *successor, ValueRange succOperands)=0
Print the successor and its operands.
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printGenericOp(Operation *op, bool printOpName=true)=0
Print the entire operation with the default generic assembly form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:329
This class helps build Operations.
Definition: Builders.h:202
static OpBuilder atBlockEnd(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still insid...
Definition: Builders.h:240
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:417
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:405
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:432
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:641
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:647
Location getLoc()
The source location the operation was defined or derived from.
Definition: OpDefinition.h:109
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:400
A trait to mark ops that can be enclosed/wrapped in a SpecConstantOperation op.
Definition: SPIRVOpTraits.h:33
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
type_range getType() const
Definition: ValueRange.cpp:30
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:75
Value getOperand(unsigned idx)
Definition: Operation.h:329
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. hasTrait<OperandsAreSignlessIntegerLike>().
Definition: Operation.h:592
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:204
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:437
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:433
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition: Operation.h:447
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:386
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:207
unsigned getNumOperands()
Definition: Operation.h:325
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:218
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:418
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:234
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:550
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:103
operand_type_range getOperandTypes()
Definition: Operation.h:376
result_type_range getResultTypes()
Definition: Operation.h:407
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:357
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:520
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:383
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:38
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:48
This class represents success/failure for parsing-like operations that find it important to chain tog...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
void push_back(Block *block)
Definition: Region.h:61
bool empty()
Definition: Region.h:60
Block & back()
Definition: Region.h:64
This class models how operands are forwarded to block arguments in control flow.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:59
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:80
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
U cast() const
Definition: Types.h:321
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:118
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
U dyn_cast() const
Definition: Types.h:311
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:108
bool isa() const
Definition: Types.h:301
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:112
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:370
type_range getTypes() const
Type front()
Return first type in the range.
Definition: TypeRange.h:148
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
Type getType() const
Return the type of this value.
Definition: Value.h:122
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:34
static WalkResult advance()
Definition: Visitors.h:52
Type getElementType() const
Definition: SPIRVTypes.cpp:65
unsigned getNumElements() const
Definition: SPIRVTypes.cpp:63
unsigned getNumElements() const
Return the number of elements of the type.
Definition: SPIRVTypes.cpp:126
unsigned getRows() const
Returns the number of rows of the matrix.
Definition: SPIRVTypes.cpp:243
unsigned getColumns() const
Returns the number of columns of the matrix.
Definition: SPIRVTypes.cpp:245
Scope getScope() const
Return the scope of the cooperative matrix.
Definition: SPIRVTypes.cpp:241
ImageArrayedInfo getArrayedInfo() const
Definition: SPIRVTypes.cpp:423
ImageSamplerUseInfo getSamplerUseInfo() const
Definition: SPIRVTypes.cpp:431
ImageSamplingInfo getSamplingInfo() const
Definition: SPIRVTypes.cpp:427
Scope getScope() const
Return the scope of the joint matrix.
Definition: SPIRVTypes.cpp:307
unsigned getColumns() const
Returns the number of columns of the matrix.
Definition: SPIRVTypes.cpp:311
unsigned getRows() const
Returns the number of rows of the matrix.
Definition: SPIRVTypes.cpp:309
unsigned getNumColumns() const
Returns the number of columns.
Type getElementType() const
Returns the elements' type (i.e., single element type).
unsigned getNumRows() const
Returns the number of rows.
Type getPointeeType() const
Definition: SPIRVTypes.cpp:482
StorageClass getStorageClass() const
Definition: SPIRVTypes.cpp:484
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:478
SPIR-V struct type.
Definition: SPIRVTypes.h:282
unsigned getNumElements() const
Type getElementType(unsigned) const
An attribute that specifies the SPIR-V (version, capabilities, extensions) triple.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:137
ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the argument at 'index'.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
ParseResult parseFunctionSignature(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
Type getFunctionType(Builder &builder, ArrayRef< OpAsmParser::Argument > argAttrs, ArrayRef< Type > resultTypes)
Get a function type corresponding to an array of arguments (which have types) and a set of result typ...
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function attributes prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
AddressingModel getAddressingModel(TargetEnvAttr targetAttr, bool use64bitAddress)
Returns addressing model selected based on target environment.
FailureOr< ExecutionModel > getExecutionModel(TargetEnvAttr targetAttr)
Returns execution model selected based on target environment.
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
TargetEnvAttr getDefaultTargetEnv(MLIRContext *context)
Returns the default target environment: SPIR-V 1.0 with Shader capability and no extra extensions.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
static ParseResult parseEnumKeywordAttr(EnumClass &value, ParserType &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next keyword in parser as an enumerant of the given EnumClass.
Definition: ParserUtils.h:27
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:374
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
A callable is either a symbol, or an SSA value, that is referenced by a call-like operation.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addSuccessors(Block *successor)
void addTypes(ArrayRef< Type > newTypes)
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.