MLIR  16.0.0git
SPIRVOps.cpp
Go to the documentation of this file.
1 //===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the operations in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
22 #include "mlir/IR/Builders.h"
23 #include "mlir/IR/BuiltinTypes.h"
25 #include "mlir/IR/OpDefinition.h"
27 #include "mlir/IR/Operation.h"
28 #include "mlir/IR/TypeUtilities.h"
30 #include "llvm/ADT/APFloat.h"
31 #include "llvm/ADT/APInt.h"
32 #include "llvm/ADT/STLExtras.h"
33 #include "llvm/ADT/StringExtras.h"
34 #include "llvm/ADT/bit.h"
35 #include <cassert>
36 #include <numeric>
37 
38 using namespace mlir;
39 
// TODO: generate these strings using ODS.
//
// Attribute and keyword spellings shared by the custom parsers, printers and
// verifiers in this file (listed alphabetically by constant name).
constexpr char kAlignmentAttrName[] = "alignment";
constexpr char kBranchWeightAttrName[] = "branch_weights";
constexpr char kCallee[] = "callee";
constexpr char kClusterSize[] = "cluster_size";
constexpr char kCompositeSpecConstituentsName[] = "constituents";
constexpr char kControl[] = "control";
constexpr char kDefaultValueAttrName[] = "default_value";
constexpr char kEqualSemanticsAttrName[] = "equal_semantics";
constexpr char kExecutionScopeAttrName[] = "execution_scope";
constexpr char kFnNameAttrName[] = "fn";
constexpr char kGroupOperationAttrName[] = "group_operation";
constexpr char kIndicesAttrName[] = "indices";
constexpr char kInitializerAttrName[] = "initializer";
constexpr char kInterfaceAttrName[] = "interface";
constexpr char kMemoryAccessAttrName[] = "memory_access";
constexpr char kMemoryScopeAttrName[] = "memory_scope";
constexpr char kSemanticsAttrName[] = "semantics";
constexpr char kSourceAlignmentAttrName[] = "source_alignment";
constexpr char kSourceMemoryAccessAttrName[] = "source_memory_access";
constexpr char kSpecIdAttrName[] = "spec_id";
constexpr char kTypeAttrName[] = "type";
constexpr char kUnequalSemanticsAttrName[] = "unequal_semantics";
constexpr char kValueAttrName[] = "value";
constexpr char kValuesAttrName[] = "values";
65 
66 //===----------------------------------------------------------------------===//
67 // Common utility functions
68 //===----------------------------------------------------------------------===//
69 
71  OperationState &result) {
// Parses an op in one of two forms. A leading `(` selects the generic form
// `(operands) attr-dict : function-type`. Otherwise the short form
// `operands attr-dict : type` is parsed, in which every operand and the
// single result use the same `type`.
73  Type type;
74  // If the operand list is in-between parentheses, then we have a generic form.
75  // (see the fallback in `printOneResultOp`).
76  SMLoc loc = parser.getCurrentLocation();
77  if (!parser.parseOptionalLParen()) {
// Generic form: operand and result types come from the function type.
78  if (parser.parseOperandList(ops) || parser.parseRParen() ||
79  parser.parseOptionalAttrDict(result.attributes) ||
80  parser.parseColon() || parser.parseType(type))
81  return failure();
82  auto fnType = type.dyn_cast<FunctionType>();
83  if (!fnType) {
84  parser.emitError(loc, "expected function type");
85  return failure();
86  }
87  if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands))
88  return failure();
89  result.addTypes(fnType.getResults());
90  return success();
91  }
// Short form: resolve every operand to `type` and add `type` as the result.
92  return failure(parser.parseOperandList(ops) ||
93  parser.parseOptionalAttrDict(result.attributes) ||
94  parser.parseColonType(type) ||
95  parser.resolveOperands(ops, type, result.operands) ||
96  parser.addTypeToList(type, result.types));
97 }
98 
/// Prints a single-result op. When every operand type equals the result type,
/// a short form ending in ` : <type>` is printed; otherwise the generic form
/// (without the op name) is used so that no type information is dropped.
99 static void printOneResultOp(Operation *op, OpAsmPrinter &p) {
100  assert(op->getNumResults() == 1 && "op should have one result");
101 
102  // If not all the operand and result types are the same, just use the
103  // generic assembly form to avoid omitting information in printing.
104  auto resultType = op->getResult(0).getType();
105  if (llvm::any_of(op->getOperandTypes(),
106  [&](Type type) { return type != resultType; })) {
107  p.printGenericOp(op, /*printOpName=*/false);
108  return;
109  }
110 
111  p << ' ';
112  p.printOperands(op->getOperands());
114  // Now we can output only one type for all operands and the result.
115  p << " : " << resultType;
116 }
117 
118 /// Returns true if the given op is a function-like op or nested in a
119 /// function-like op without a module-like op in the middle.
// A null op is never considered nested in a function.
121  if (!op)
122  return false;
// An op carrying the SymbolTable trait counts as module-like: answer false.
123  if (op->hasTrait<OpTrait::SymbolTable>())
124  return false;
125  if (isa<FunctionOpInterface>(op))
126  return true;
128 }
129 
130 /// Returns true if the given op is a module-like op that maintains a symbol
131 /// table, i.e. it is non-null and has the `OpTrait::SymbolTable` trait.
133  return op && op->hasTrait<OpTrait::SymbolTable>();
134 }
135 
// Succeeds only when `op` is a non-null spirv::ConstantOp whose value is an
// IntegerAttr; the integer is then stored into `value`.
137  auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
138  if (!constOp) {
139  return failure();
140  }
141  auto valueAttr = constOp.getValue();
142  auto integerValueAttr = valueAttr.dyn_cast<IntegerAttr>();
143  if (!integerValueAttr) {
144  return failure();
145  }
146 
// Signless integers are read with getInt(); all other integer types with
// getSInt().
// NOTE(review): IntegerAttr::getSInt() expects a signed integer type, so an
// unsigned (ui*) constant would trip its assertion — confirm callers never
// pass one.
147  if (integerValueAttr.getType().isSignlessInteger())
148  value = integerValueAttr.getInt();
149  else
150  value = integerValueAttr.getSInt();
151 
152  return success();
153 }
154 
/// Converts each value in `enumValues` to a string with `stringifyFn` and
/// returns the strings as an ArrayAttr built with `builder`. Returns a null
/// ArrayAttr when `enumValues` is empty.
155 template <typename Ty>
156 static ArrayAttr
158  function_ref<StringRef(Ty)> stringifyFn) {
159  if (enumValues.empty()) {
160  return nullptr;
161  }
162  SmallVector<StringRef, 1> enumValStrs;
163  enumValStrs.reserve(enumValues.size());
164  for (auto val : enumValues) {
165  enumValStrs.emplace_back(stringifyFn(val));
166  }
167  return builder.getStrArrayAttr(enumValStrs);
168 }
169 
170 /// Parses the next string attribute in `parser` as an enumerant of the given
171 /// `EnumClass`.
172 template <typename EnumClass>
173 static ParseResult
174 parseEnumStrAttr(EnumClass &value, OpAsmParser &parser,
175  StringRef attrName = spirv::attributeName<EnumClass>()) {
176  Attribute attrVal;
177  NamedAttrList attr;
178  auto loc = parser.getCurrentLocation();
179  if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(),
180  attrName, attr))
181  return failure();
182  if (!attrVal.isa<StringAttr>())
183  return parser.emitError(loc, "expected ")
184  << attrName << " attribute specified as string";
185  auto attrOptional =
186  spirv::symbolizeEnum<EnumClass>(attrVal.cast<StringAttr>().getValue());
187  if (!attrOptional)
188  return parser.emitError(loc, "invalid ")
189  << attrName << " attribute specification: " << attrVal;
190  value = *attrOptional;
191  return success();
192 }
193 
194 /// Parses the next string attribute in `parser` as an enumerant of the given
195 /// `EnumClass` and inserts the enumerant into `state` as an 32-bit integer
196 /// attribute with the enum class's name as attribute name.
197 template <typename EnumAttrClass,
198  typename EnumClass = typename EnumAttrClass::ValueType>
199 static ParseResult
200 parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, OperationState &state,
201  StringRef attrName = spirv::attributeName<EnumClass>()) {
202  if (parseEnumStrAttr(value, parser))
203  return failure();
204  state.addAttribute(attrName,
205  parser.getBuilder().getAttr<EnumAttrClass>(value));
206  return success();
207 }
208 
209 /// Parses the next keyword in `parser` as an enumerant of the given `EnumClass`
210 /// and inserts the enumerant into `state` as an `EnumAttrClass` attribute
211 /// named `attrName` (default: `spirv::attributeName<EnumClass>()`).
212 template <typename EnumAttrClass,
213  typename EnumClass = typename EnumAttrClass::ValueType>
214 static ParseResult
216  OperationState &state,
217  StringRef attrName = spirv::attributeName<EnumClass>()) {
// Parse into `value` first; the attribute is only recorded on success.
218  if (parseEnumKeywordAttr(value, parser))
219  return failure();
220  state.addAttribute(attrName,
221  parser.getBuilder().getAttr<EnumAttrClass>(value));
222  return success();
223 }
224 
225 /// Parses Function, Selection and Loop control attributes. If no control is
226 /// specified, "None" is used as a default.
227 template <typename EnumAttrClass, typename EnumClass>
228 static ParseResult
230  StringRef attrName = spirv::attributeName<EnumClass>()) {
231  if (succeeded(parser.parseOptionalKeyword(kControl))) {
232  EnumClass control;
233  if (parser.parseLParen() ||
234  parseEnumKeywordAttr<EnumAttrClass>(control, parser, state) ||
235  parser.parseRParen())
236  return failure();
237  return success();
238  }
239  // Set control to "None" otherwise.
240  Builder builder = parser.getBuilder();
241  state.addAttribute(attrName,
242  builder.getAttr<EnumAttrClass>(static_cast<EnumClass>(0)));
243  return success();
244 }
245 
246 /// Parses optional memory access attributes attached to a memory access
247 /// operand/pointer. Specifically, parses the following syntax:
248 /// (`[` memory-access `]`)?
249 /// where:
250 /// memory-access ::= `"None"` | `"Volatile"` | `"Aligned", `
251 /// integer-literal | `"NonTemporal"`
/// When the parsed access includes `Aligned`, the integer literal is stored
/// on `state` as an i32 attribute named `alignment`.
253  OperationState &state) {
254  // Parse an optional list of attributes starting with '['
255  if (parser.parseOptionalLSquare()) {
256  // Nothing to do
257  return success();
258  }
259 
260  spirv::MemoryAccess memoryAccessAttr;
261  if (parseEnumStrAttr<spirv::MemoryAccessAttr>(memoryAccessAttr, parser, state,
263  return failure();
264 
265  if (spirv::bitEnumContainsAll(memoryAccessAttr,
266  spirv::MemoryAccess::Aligned)) {
267  // Parse integer attribute for alignment.
268  Attribute alignmentAttr;
269  Type i32Type = parser.getBuilder().getIntegerType(32);
270  if (parser.parseComma() ||
271  parser.parseAttribute(alignmentAttr, i32Type, kAlignmentAttrName,
272  state.attributes)) {
273  return failure();
274  }
275  }
276  return parser.parseRSquare();
277 }
278 
279 // TODO Make sure to merge this and the previous function into one template
280 // parameterized by memory access attribute name and alignment. Doing so now
281 // results in VS2017 producing an internal error (at the call site) that's
282 // not detailed enough to understand what is happening.
/// Parses an optional `[ memory-access (, integer-literal)? ]` list for the
/// source operand. When the access includes `Aligned`, the literal is stored
/// on `state` as an i32 attribute named `source_alignment`.
284  OperationState &state) {
285  // Parse an optional list of attributes starting with '['
286  if (parser.parseOptionalLSquare()) {
287  // Nothing to do
288  return success();
289  }
290 
291  spirv::MemoryAccess memoryAccessAttr;
292  if (parseEnumStrAttr<spirv::MemoryAccessAttr>(memoryAccessAttr, parser, state,
294  return failure();
295 
296  if (spirv::bitEnumContainsAll(memoryAccessAttr,
297  spirv::MemoryAccess::Aligned)) {
298  // Parse integer attribute for alignment.
299  Attribute alignmentAttr;
300  Type i32Type = parser.getBuilder().getIntegerType(32);
301  if (parser.parseComma() ||
302  parser.parseAttribute(alignmentAttr, i32Type, kSourceAlignmentAttrName,
303  state.attributes)) {
304  return failure();
305  }
306  }
307  return parser.parseRSquare();
308 }
309 
/// Prints the optional ` ["<memory-access>"(, <alignment>)?]` suffix. The
/// explicit `memoryAccessAtrrValue` / `alignmentAttrValue` arguments take
/// precedence over the op's own accessors. Every attribute name that was
/// printed, plus the storage-class attribute name, is appended to
/// `elidedAttrs`.
310 template <typename MemoryOpTy>
312  MemoryOpTy memoryOp, OpAsmPrinter &printer,
313  SmallVectorImpl<StringRef> &elidedAttrs,
314  Optional<spirv::MemoryAccess> memoryAccessAtrrValue = None,
315  Optional<uint32_t> alignmentAttrValue = None) {
316  // Print optional memory access attribute.
317  if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
318  : memoryOp.getMemoryAccess())) {
319  elidedAttrs.push_back(kMemoryAccessAttrName);
320 
321  printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
322 
323  if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
324  // Print integer alignment attribute.
325  if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
326  : memoryOp.getAlignment())) {
327  elidedAttrs.push_back(kAlignmentAttrName);
328  printer << ", " << alignment;
329  }
330  }
331  printer << "]";
332  }
333  elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
334 }
335 
336 // TODO Make sure to merge this and the previous function into one template
337 // parameterized by memory access attribute name and alignment. Doing so now
338 // results in VS2017 producing an internal error (at the call site) that's
339 // not detailed enough to understand what is happening.
/// Source-operand variant of the memory access printer: always prints a
/// leading `, `, then the optional ` ["<memory-access>"(, <alignment>)?]`
/// list, eliding `source_memory_access` / `source_alignment` and the
/// storage-class attribute name.
/// NOTE(review): when no explicit values are passed, this falls back to
/// `memoryOp.getMemoryAccess()` / `memoryOp.getAlignment()` (the non-source
/// accessors) — confirm callers always pass the source values explicitly.
340 template <typename MemoryOpTy>
342  MemoryOpTy memoryOp, OpAsmPrinter &printer,
343  SmallVectorImpl<StringRef> &elidedAttrs,
344  Optional<spirv::MemoryAccess> memoryAccessAtrrValue = None,
345  Optional<uint32_t> alignmentAttrValue = None) {
346 
347  printer << ", ";
348 
349  // Print optional memory access attribute.
350  if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
351  : memoryOp.getMemoryAccess())) {
352  elidedAttrs.push_back(kSourceMemoryAccessAttrName);
353 
354  printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
355 
356  if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
357  // Print integer alignment attribute.
358  if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
359  : memoryOp.getAlignment())) {
360  elidedAttrs.push_back(kSourceAlignmentAttrName);
361  printer << ", " << alignment;
362  }
363  }
364  printer << "]";
365  }
366  elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
367 }
368 
370  spirv::ImageOperandsAttr &attr) {
// Parses an optional `["<image-operands>"]` list into `attr`. When no `[`
// follows, this succeeds and leaves `attr` untouched.
371  // Expect image operands
372  if (parser.parseOptionalLSquare())
373  return success();
374 
375  spirv::ImageOperands imageOperands;
376  if (parseEnumStrAttr(imageOperands, parser))
377  return failure();
378 
379  attr = spirv::ImageOperandsAttr::get(parser.getContext(), imageOperands);
380 
381  return parser.parseRSquare();
382 }
383 
384 static void printImageOperands(OpAsmPrinter &printer, Operation *imageOp,
385  spirv::ImageOperandsAttr attr) {
386  if (attr) {
387  auto strImageOperands = stringifyImageOperands(attr.getValue());
388  printer << "[\"" << strImageOperands << "\"]";
389  }
390 }
391 
392 template <typename Op>
394  spirv::ImageOperandsAttr attr,
395  Operation::operand_range operands) {
396  if (!attr) {
397  if (operands.empty())
398  return success();
399 
400  return imageOp.emitError("the Image Operands should encode what operands "
401  "follow, as per Image Operands");
402  }
403 
404  // TODO: Add the validation rules for the following Image Operands.
405  spirv::ImageOperands noSupportOperands =
406  spirv::ImageOperands::Bias | spirv::ImageOperands::Lod |
407  spirv::ImageOperands::Grad | spirv::ImageOperands::ConstOffset |
408  spirv::ImageOperands::Offset | spirv::ImageOperands::ConstOffsets |
409  spirv::ImageOperands::Sample | spirv::ImageOperands::MinLod |
410  spirv::ImageOperands::MakeTexelAvailable |
411  spirv::ImageOperands::MakeTexelVisible |
412  spirv::ImageOperands::SignExtend | spirv::ImageOperands::ZeroExtend;
413 
414  if (spirv::bitEnumContainsAll(attr.getValue(), noSupportOperands))
415  llvm_unreachable("unimplemented operands of Image Operands");
416 
417  return success();
418 }
419 
421  bool requireSameBitWidth = true,
422  bool skipBitWidthCheck = false) {
// Compares the scalar bit widths of operand 0 and result 0 after unwrapping
// vector, cooperative-matrix and joint-matrix types to their element types.
// With `requireSameBitWidth` the widths must match; otherwise they must
// differ. `skipBitWidthCheck` bypasses the comparison entirely.
423  // Some CastOps have no limit on bit widths for result and operand type.
424  if (skipBitWidthCheck)
425  return success();
426 
427  Type operandType = op->getOperand(0).getType();
428  Type resultType = op->getResult(0).getType();
429 
430  // ODS checks that result type and operand type have the same shape.
431  if (auto vectorType = operandType.dyn_cast<VectorType>()) {
432  operandType = vectorType.getElementType();
433  resultType = resultType.cast<VectorType>().getElementType();
434  }
435 
436  if (auto coopMatrixType =
437  operandType.dyn_cast<spirv::CooperativeMatrixNVType>()) {
438  operandType = coopMatrixType.getElementType();
439  resultType =
441  }
442 
443  if (auto jointMatrixType =
444  operandType.dyn_cast<spirv::JointMatrixINTELType>()) {
445  operandType = jointMatrixType.getElementType();
446  resultType =
448  }
449 
450  auto operandTypeBitWidth = operandType.getIntOrFloatBitWidth();
451  auto resultTypeBitWidth = resultType.getIntOrFloatBitWidth();
452  auto isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth;
453 
454  if (requireSameBitWidth) {
455  if (!isSameBitWidth) {
456  return op->emitOpError(
457  "expected the same bit widths for operand type and result "
458  "type, but provided ")
459  << operandType << " and " << resultType;
460  }
461  return success();
462  }
463 
464  if (isSameBitWidth) {
465  return op->emitOpError(
466  "expected the different bit widths for operand type and result "
467  "type, but provided ")
468  << operandType << " and " << resultType;
469  }
470  return success();
471 }
472 
473 template <typename MemoryOpTy>
474 static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
475  // ODS checks for attributes values. Just need to verify that if the
476  // memory-access attribute is Aligned, then the alignment attribute must be
477  // present.
478  auto *op = memoryOp.getOperation();
479  auto memAccessAttr = op->getAttr(kMemoryAccessAttrName);
480  if (!memAccessAttr) {
481  // Alignment attribute shouldn't be present if memory access attribute is
482  // not present.
483  if (op->getAttr(kAlignmentAttrName)) {
484  return memoryOp.emitOpError(
485  "invalid alignment specification without aligned memory access "
486  "specification");
487  }
488  return success();
489  }
490 
491  auto memAccess = memAccessAttr.template cast<spirv::MemoryAccessAttr>();
492 
493  if (!memAccess) {
494  return memoryOp.emitOpError("invalid memory access specifier: ")
495  << memAccessAttr;
496  }
497 
498  if (spirv::bitEnumContainsAll(memAccess.getValue(),
499  spirv::MemoryAccess::Aligned)) {
500  if (!op->getAttr(kAlignmentAttrName)) {
501  return memoryOp.emitOpError("missing alignment value");
502  }
503  } else {
504  if (op->getAttr(kAlignmentAttrName)) {
505  return memoryOp.emitOpError(
506  "invalid alignment specification with non-aligned memory access "
507  "specification");
508  }
509  }
510  return success();
511 }
512 
513 // TODO Make sure to merge this and the previous function into one template
514 // parameterized by memory access attribute name and alignment. Doing so now
515 // results in VS2017 in producing an internal error (at the call site) that's
516 // not detailed enough to understand what is happening.
517 template <typename MemoryOpTy>
519  // ODS checks for attributes values. Just need to verify that if the
520  // memory-access attribute is Aligned, then the alignment attribute must be
521  // present.
522  auto *op = memoryOp.getOperation();
523  auto memAccessAttr = op->getAttr(kSourceMemoryAccessAttrName);
524  if (!memAccessAttr) {
525  // Alignment attribute shouldn't be present if memory access attribute is
526  // not present.
527  if (op->getAttr(kSourceAlignmentAttrName)) {
528  return memoryOp.emitOpError(
529  "invalid alignment specification without aligned memory access "
530  "specification");
531  }
532  return success();
533  }
534 
535  auto memAccess = memAccessAttr.template cast<spirv::MemoryAccessAttr>();
536 
537  if (!memAccess) {
538  return memoryOp.emitOpError("invalid memory access specifier: ")
539  << memAccess;
540  }
541 
542  if (spirv::bitEnumContainsAll(memAccess.getValue(),
543  spirv::MemoryAccess::Aligned)) {
544  if (!op->getAttr(kSourceAlignmentAttrName)) {
545  return memoryOp.emitOpError("missing alignment value");
546  }
547  } else {
548  if (op->getAttr(kSourceAlignmentAttrName)) {
549  return memoryOp.emitOpError(
550  "invalid alignment specification with non-aligned memory access "
551  "specification");
552  }
553  }
554  return success();
555 }
556 
557 static LogicalResult
558 verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics) {
559  // According to the SPIR-V specification:
560  // "Despite being a mask and allowing multiple bits to be combined, it is
561  // invalid for more than one of these four bits to be set: Acquire, Release,
562  // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and
563  // Release semantics is done by setting the AcquireRelease bit, not by setting
564  // two bits."
565  auto atMostOneInSet = spirv::MemorySemantics::Acquire |
566  spirv::MemorySemantics::Release |
567  spirv::MemorySemantics::AcquireRelease |
568  spirv::MemorySemantics::SequentiallyConsistent;
569 
570  auto bitCount = llvm::countPopulation(
571  static_cast<uint32_t>(memorySemantics & atMostOneInSet));
572  if (bitCount > 1) {
573  return op->emitError(
574  "expected at most one of these four memory constraints "
575  "to be set: `Acquire`, `Release`,"
576  "`AcquireRelease` or `SequentiallyConsistent`");
577  }
578  return success();
579 }
580 
/// Verifies that the type of `val` (the loaded result or the stored value)
/// matches the pointee type of `ptr`; otherwise emits an op error on `op`.
581 template <typename LoadStoreOpTy>
582 static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr,
583  Value val) {
584  // ODS already checks ptr is spirv::PointerType. Just check that the pointee
585  // type of the pointer and the type of the value are the same
586  //
587  // TODO: Check that the value type satisfies restrictions of
588  // SPIR-V OpLoad/OpStore operations
589  if (val.getType() !=
591  return op.emitOpError("mismatch in result type and pointer type");
592  }
593  return success();
594 }
595 
/// Verifies that the element type of `val` (or the type itself when `val` is
/// not a vector) equals the pointee type of `ptr`; otherwise emits an op error.
596 template <typename BlockReadWriteOpTy>
598  Value ptr, Value val) {
599  auto valType = val.getType();
600  if (auto valVecTy = valType.dyn_cast<VectorType>())
601  valType = valVecTy.getElementType();
602 
603  if (valType != ptr.getType().cast<spirv::PointerType>().getPointeeType()) {
604  return op.emitOpError("mismatch in result type and pointer type");
605  }
606  return success();
607 }
608 
610  OperationState &state) {
// Parses either an optional `bind(<descriptor-set>, <binding>)` clause or an
// optional built-in clause (keyword derived from the BuiltIn decoration name,
// followed by a parenthesized string), then an optional attribute dictionary.
// Descriptor set and binding are stored as i32 attributes named after their
// decorations (converted to snake_case).
611  auto builtInName = llvm::convertToSnakeFromCamelCase(
612  stringifyDecoration(spirv::Decoration::BuiltIn));
613  if (succeeded(parser.parseOptionalKeyword("bind"))) {
614  Attribute set, binding;
615  // Parse optional descriptor binding
616  auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
617  stringifyDecoration(spirv::Decoration::DescriptorSet));
618  auto bindingName = llvm::convertToSnakeFromCamelCase(
619  stringifyDecoration(spirv::Decoration::Binding));
620  Type i32Type = parser.getBuilder().getIntegerType(32);
621  if (parser.parseLParen() ||
622  parser.parseAttribute(set, i32Type, descriptorSetName,
623  state.attributes) ||
624  parser.parseComma() ||
625  parser.parseAttribute(binding, i32Type, bindingName,
626  state.attributes) ||
627  parser.parseRParen()) {
628  return failure();
629  }
630  } else if (succeeded(parser.parseOptionalKeyword(builtInName))) {
631  StringAttr builtIn;
632  if (parser.parseLParen() ||
633  parser.parseAttribute(builtIn, builtInName, state.attributes) ||
634  parser.parseRParen()) {
635  return failure();
636  }
637  }
638 
639  // Parse other attributes
640  if (parser.parseOptionalAttrDict(state.attributes))
641  return failure();
642 
643  return success();
644 }
645 
647  SmallVectorImpl<StringRef> &elidedAttrs) {
// Prints ` bind(<set>, <binding>)` when both attributes are present, then the
// built-in clause when present, then the remaining attributes.
// NOTE(review): the names pushed into `elidedAttrs` below are StringRefs into
// local std::strings returned by convertToSnakeFromCamelCase. They dangle once
// this function returns, so callers must not read `elidedAttrs` afterwards.
648  // Print optional descriptor binding
649  auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
650  stringifyDecoration(spirv::Decoration::DescriptorSet));
651  auto bindingName = llvm::convertToSnakeFromCamelCase(
652  stringifyDecoration(spirv::Decoration::Binding));
653  auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
654  auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
655  if (descriptorSet && binding) {
656  elidedAttrs.push_back(descriptorSetName);
657  elidedAttrs.push_back(bindingName);
658  printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
659  << ")";
660  }
661 
662  // Print BuiltIn attribute if present
663  auto builtInName = llvm::convertToSnakeFromCamelCase(
664  stringifyDecoration(spirv::Decoration::BuiltIn));
665  if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
666  printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
667  elidedAttrs.push_back(builtInName);
668  }
669 
670  printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
671 }
672 
673 // Get bit width of types.
674 static unsigned getBitWidth(Type type) {
675  if (type.isa<spirv::PointerType>()) {
676  // Just return 64 bits for pointer types for now.
677  // TODO: Make sure not caller relies on the actual pointer width value.
678  return 64;
679  }
680 
681  if (type.isIntOrFloat())
682  return type.getIntOrFloatBitWidth();
683 
684  if (auto vectorType = type.dyn_cast<VectorType>()) {
685  assert(vectorType.getElementType().isIntOrFloat());
686  return vectorType.getNumElements() *
687  vectorType.getElementType().getIntOrFloatBitWidth();
688  }
689  llvm_unreachable("unhandled bit width computation for type");
690 }
691 
692 /// Walks the given type hierarchy with the given indices, potentially down
693 /// to component granularity, to select an element type. Returns null type and
694 /// reports errors through `emitErrorFn` on failure (including an empty index
/// list, and an out-of-bounds index on a composite whose element count is
/// known at compile time).
695 static Type
697  function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
698  if (indices.empty()) {
699  emitErrorFn("expected at least one index for spirv.CompositeExtract");
700  return nullptr;
701  }
702 
703  for (auto index : indices) {
704  if (auto cType = type.dyn_cast<spirv::CompositeType>()) {
705  if (cType.hasCompileTimeKnownNumElements() &&
706  (index < 0 ||
707  static_cast<uint64_t>(index) >= cType.getNumElements())) {
708  emitErrorFn("index ") << index << " out of bounds for " << type;
709  return nullptr;
710  }
711  type = cType.getElementType(index);
712  } else {
713  emitErrorFn("cannot extract from non-composite type ")
714  << type << " with index " << index;
715  return nullptr;
716  }
717  }
718  return type;
719 }
720 
/// Attribute-based overload: `indices` must be a non-empty ArrayAttr of
/// IntegerAttrs. Their values are collected as int32_t and passed to the
/// index-list overload. Returns a null type after reporting through
/// `emitErrorFn` on failure.
721 static Type
723  function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
724  auto indicesArrayAttr = indices.dyn_cast<ArrayAttr>();
725  if (!indicesArrayAttr) {
726  emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
727  return nullptr;
728  }
729  if (indicesArrayAttr.empty()) {
730  emitErrorFn("expected at least one index for spirv.CompositeExtract");
731  return nullptr;
732  }
733 
734  SmallVector<int32_t, 2> indexVals;
735  for (auto indexAttr : indicesArrayAttr) {
736  auto indexIntAttr = indexAttr.dyn_cast<IntegerAttr>();
737  if (!indexIntAttr) {
738  emitErrorFn("expected an 32-bit integer for index, but found '")
739  << indexAttr << "'";
740  return nullptr;
741  }
742  indexVals.push_back(indexIntAttr.getInt());
743  }
744  return getElementType(type, indexVals, emitErrorFn);
745 }
746 
747 static Type getElementType(Type type, Attribute indices, Location loc) {
748  auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
749  return ::mlir::emitError(loc, err);
750  };
751  return getElementType(type, indices, errorFn);
752 }
753 
754 static Type getElementType(Type type, Attribute indices, OpAsmParser &parser,
755  SMLoc loc) {
756  auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
757  return parser.emitError(loc, err);
758  };
759  return getElementType(type, indices, errorFn);
760 }
761 
762 /// Returns true if the given `block` only contains one `spirv.mlir.merge` op.
763 static inline bool isMergeBlock(Block &block) {
764  return !block.empty() && std::next(block.begin()) == block.end() &&
765  isa<spirv::MergeOp>(block.front());
766 }
767 
768 template <typename ExtendedBinaryOp>
769 static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op) {
770  auto resultType = op.getType().template cast<spirv::StructType>();
771  if (resultType.getNumElements() != 2)
772  return op.emitOpError("expected result struct type containing two members");
773 
774  if (!llvm::all_equal({op.getOperand1().getType(), op.getOperand2().getType(),
775  resultType.getElementType(0),
776  resultType.getElementType(1)}))
777  return op.emitOpError(
778  "expected all operand types and struct member types are the same");
779 
780  return success();
781 }
782 
784  OperationState &result) {
// Parses `attr-dict operands : <result-type>`. The result type must be a
// spirv.struct with exactly two members, and every operand is resolved to
// that struct's first member type.
786  if (parser.parseOptionalAttrDict(result.attributes) ||
787  parser.parseOperandList(operands) || parser.parseColon())
788  return failure();
789 
790  Type resultType;
791  SMLoc loc = parser.getCurrentLocation();
792  if (parser.parseType(resultType))
793  return failure();
794 
795  auto structType = resultType.dyn_cast<spirv::StructType>();
796  if (!structType || structType.getNumElements() != 2)
797  return parser.emitError(loc, "expected spirv.struct type with two members");
798 
799  SmallVector<Type, 2> operandTypes(2, structType.getElementType(0));
800  if (parser.resolveOperands(operands, operandTypes, loc, result.operands))
801  return failure();
802 
803  result.addTypes(resultType);
804  return success();
805 }
806 
808  OpAsmPrinter &printer) {
// Prints ` attr-dict operands : <first-result-type>`; the attribute
// dictionary comes before the operands.
809  printer << ' ';
810  printer.printOptionalAttrDict(op->getAttrs());
811  printer.printOperands(op->getOperands());
812  printer << " : " << op->getResultTypes().front();
813 }
814 
815 //===----------------------------------------------------------------------===//
816 // Common parsers and printers
817 //===----------------------------------------------------------------------===//
818 
819 // Parses an atomic update op. If the update op does not take a value (like
820 // AtomicIIncrement) `hasValue` must be false.
// Form: two quoted enum strings (scope, then memory semantics), the pointer
// operand (plus a value operand when `hasValue`), then `: <pointer-type>`.
// The value operand and the result both get the pointee type.
822  OperationState &state, bool hasValue) {
823  spirv::Scope scope;
// NOTE(review): despite its name, `memoryScope` receives the memory semantics.
824  spirv::MemorySemantics memoryScope;
// NOTE(review): `ptrInfo` and `valueInfo` are not used in the visible code.
826  OpAsmParser::UnresolvedOperand ptrInfo, valueInfo;
827  Type type;
828  SMLoc loc;
829  if (parseEnumStrAttr<spirv::ScopeAttr>(scope, parser, state,
831  parseEnumStrAttr<spirv::MemorySemanticsAttr>(memoryScope, parser, state,
833  parser.parseOperandList(operandInfo, (hasValue ? 2 : 1)) ||
834  parser.getCurrentLocation(&loc) || parser.parseColonType(type))
835  return failure();
836 
837  auto ptrType = type.dyn_cast<spirv::PointerType>();
838  if (!ptrType)
839  return parser.emitError(loc, "expected pointer type");
840 
841  SmallVector<Type, 2> operandTypes;
842  operandTypes.push_back(ptrType);
843  if (hasValue)
844  operandTypes.push_back(ptrType.getPointeeType());
845  if (parser.resolveOperands(operandInfo, operandTypes, parser.getNameLoc(),
846  state.operands))
847  return failure();
848  return parser.addTypeToList(ptrType.getPointeeType(), state.types);
849 }
850 
851 // Prints an atomic update op.
852 static void printAtomicUpdateOp(Operation *op, OpAsmPrinter &printer) {
853  printer << " \"";
854  auto scopeAttr = op->getAttrOfType<spirv::ScopeAttr>(kMemoryScopeAttrName);
855  printer << spirv::stringifyScope(scopeAttr.getValue()) << "\" \"";
856  auto memorySemanticsAttr =
857  op->getAttrOfType<spirv::MemorySemanticsAttr>(kSemanticsAttrName);
858  printer << spirv::stringifyMemorySemantics(memorySemanticsAttr.getValue())
859  << "\" " << op->getOperands() << " : " << op->getOperand(0).getType();
860 }
861 
862 template <typename T>
863 static StringRef stringifyTypeName();
864 
/// Specialization returning "integer" for use in diagnostics.
865 template <>
867  return "integer";
868 }
869 
/// Specialization returning "float" for use in diagnostics.
870 template <>
872  return "float";
873 }
874 
875 // Verifies an atomic update op.
// Checks that:
//  - operand 0 points to an `ExpectedElementType` value;
//  - the optional value operand (operand 1), if present, has the pointee type;
//  - the `semantics` attribute passes verifyMemorySemantics.
876 template <typename ExpectedElementType>
878  auto ptrType = op->getOperand(0).getType().cast<spirv::PointerType>();
879  auto elementType = ptrType.getPointeeType();
880  if (!elementType.isa<ExpectedElementType>())
881  return op->emitOpError() << "pointer operand must point to an "
882  << stringifyTypeName<ExpectedElementType>()
883  << " value, found " << elementType;
884 
885  if (op->getNumOperands() > 1) {
886  auto valueType = op->getOperand(1).getType();
887  if (valueType != elementType)
888  return op->emitOpError("expected value to have the same type as the "
889  "pointer operand's pointee type ")
890  << elementType << ", but found " << valueType;
891  }
892  auto memorySemantics =
893  op->getAttrOfType<spirv::MemorySemanticsAttr>(kSemanticsAttrName)
894  .getValue();
895  if (failed(verifyMemorySemantics(op, memorySemantics))) {
896  return failure();
897  }
898  return success();
899 }
900 
902  OperationState &state) {
// Parses two quoted enum strings (execution scope, then group operation), the
// value operand, an optional parenthesized cluster-size operand, and
// `: <type>`. The value operand and the result use `type`; the cluster size
// is resolved as i32.
903  spirv::Scope executionScope;
904  spirv::GroupOperation groupOperation;
906  if (parseEnumStrAttr<spirv::ScopeAttr>(executionScope, parser, state,
908  parseEnumStrAttr<spirv::GroupOperationAttr>(groupOperation, parser, state,
910  parser.parseOperand(valueInfo))
911  return failure();
912 
915  clusterSizeInfo = OpAsmParser::UnresolvedOperand();
916  if (parser.parseLParen() || parser.parseOperand(*clusterSizeInfo) ||
917  parser.parseRParen())
918  return failure();
919  }
920 
921  Type resultType;
922  if (parser.parseColonType(resultType))
923  return failure();
924 
925  if (parser.resolveOperand(valueInfo, resultType, state.operands))
926  return failure();
927 
928  if (clusterSizeInfo) {
929  Type i32Type = parser.getBuilder().getIntegerType(32);
930  if (parser.resolveOperand(*clusterSizeInfo, i32Type, state.operands))
931  return failure();
932  }
933 
934  return parser.addTypeToList(resultType, state.types);
935 }
936 
938  OpAsmPrinter &printer) {
939  printer
940  << " \""
941  << stringifyScope(
942  groupOp->getAttrOfType<spirv::ScopeAttr>(kExecutionScopeAttrName)
943  .getValue())
944  << "\" \""
945  << stringifyGroupOperation(groupOp
946  ->getAttrOfType<spirv::GroupOperationAttr>(
948  .getValue())
949  << "\" " << groupOp->getOperand(0);
950 
951  if (groupOp->getNumOperands() > 1)
952  printer << " " << kClusterSize << '(' << groupOp->getOperand(1) << ')';
953  printer << " : " << groupOp->getResult(0).getType();
954 }
955 
957  spirv::Scope scope =
958  groupOp->getAttrOfType<spirv::ScopeAttr>(kExecutionScopeAttrName)
959  .getValue();
960  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
961  return groupOp->emitOpError(
962  "execution scope must be 'Workgroup' or 'Subgroup'");
963 
964  spirv::GroupOperation operation =
965  groupOp->getAttrOfType<spirv::GroupOperationAttr>(kGroupOperationAttrName)
966  .getValue();
967  if (operation == spirv::GroupOperation::ClusteredReduce &&
968  groupOp->getNumOperands() == 1)
969  return groupOp->emitOpError("cluster size operand must be provided for "
970  "'ClusteredReduce' group operation");
971  if (groupOp->getNumOperands() > 1) {
972  Operation *sizeOp = groupOp->getOperand(1).getDefiningOp();
973  int32_t clusterSize = 0;
974 
975  // TODO: support specialization constant here.
976  if (failed(extractValueFromConstOp(sizeOp, clusterSize)))
977  return groupOp->emitOpError(
978  "cluster size operand must come from a constant op");
979 
980  if (!llvm::isPowerOf2_32(clusterSize))
981  return groupOp->emitOpError(
982  "cluster size operand must be a power of two");
983  }
984  return success();
985 }
986 
987 /// Result of a logical op must be a scalar or vector of boolean type.
988 static Type getUnaryOpResultType(Type operandType) {
989  Builder builder(operandType.getContext());
990  Type resultType = builder.getIntegerType(1);
991  if (auto vecType = operandType.dyn_cast<VectorType>())
992  return VectorType::get(vecType.getNumElements(), resultType);
993  return resultType;
994 }
995 
997  if (op->getOperand(0).getType() != op->getResult(0).getType()) {
998  return op->emitError("expected the same type for the first operand and "
999  "result, but provided ")
1000  << op->getOperand(0).getType() << " and "
1001  << op->getResult(0).getType();
1002  }
1003  return success();
1004 }
1005 
1006 static void buildLogicalBinaryOp(OpBuilder &builder, OperationState &state,
1007  Value lhs, Value rhs) {
1008  assert(lhs.getType() == rhs.getType());
1009 
1010  Type boolType = builder.getI1Type();
1011  if (auto vecType = lhs.getType().dyn_cast<VectorType>())
1012  boolType = VectorType::get(vecType.getShape(), boolType);
1013  state.addTypes(boolType);
1014 
1015  state.addOperands({lhs, rhs});
1016 }
1017 
1018 static void buildLogicalUnaryOp(OpBuilder &builder, OperationState &state,
1019  Value value) {
1020  Type boolType = builder.getI1Type();
1021  if (auto vecType = value.getType().dyn_cast<VectorType>())
1022  boolType = VectorType::get(vecType.getShape(), boolType);
1023  state.addTypes(boolType);
1024 
1025  state.addOperands(value);
1026 }
1027 
1028 //===----------------------------------------------------------------------===//
1029 // spirv.AccessChainOp
1030 //===----------------------------------------------------------------------===//
1031 
1032 static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
1033  auto ptrType = type.dyn_cast<spirv::PointerType>();
1034  if (!ptrType) {
1035  emitError(baseLoc, "'spirv.AccessChain' op expected a pointer "
1036  "to composite type, but provided ")
1037  << type;
1038  return nullptr;
1039  }
1040 
1041  auto resultType = ptrType.getPointeeType();
1042  auto resultStorageClass = ptrType.getStorageClass();
1043  int32_t index = 0;
1044 
1045  for (auto indexSSA : indices) {
1046  auto cType = resultType.dyn_cast<spirv::CompositeType>();
1047  if (!cType) {
1048  emitError(
1049  baseLoc,
1050  "'spirv.AccessChain' op cannot extract from non-composite type ")
1051  << resultType << " with index " << index;
1052  return nullptr;
1053  }
1054  index = 0;
1055  if (resultType.isa<spirv::StructType>()) {
1056  Operation *op = indexSSA.getDefiningOp();
1057  if (!op) {
1058  emitError(baseLoc, "'spirv.AccessChain' op index must be an "
1059  "integer spirv.Constant to access "
1060  "element of spirv.struct");
1061  return nullptr;
1062  }
1063 
1064  // TODO: this should be relaxed to allow
1065  // integer literals of other bitwidths.
1066  if (failed(extractValueFromConstOp(op, index))) {
1067  emitError(
1068  baseLoc,
1069  "'spirv.AccessChain' index must be an integer spirv.Constant to "
1070  "access element of spirv.struct, but provided ")
1071  << op->getName();
1072  return nullptr;
1073  }
1074  if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) {
1075  emitError(baseLoc, "'spirv.AccessChain' op index ")
1076  << index << " out of bounds for " << resultType;
1077  return nullptr;
1078  }
1079  }
1080  resultType = cType.getElementType(index);
1081  }
1082  return spirv::PointerType::get(resultType, resultStorageClass);
1083 }
1084 
1085 void spirv::AccessChainOp::build(OpBuilder &builder, OperationState &state,
1086  Value basePtr, ValueRange indices) {
1087  auto type = getElementPtrType(basePtr.getType(), indices, state.location);
1088  assert(type && "Unable to deduce return type based on basePtr and indices");
1089  build(builder, state, type, basePtr, indices);
1090 }
1091 
1092 ParseResult spirv::AccessChainOp::parse(OpAsmParser &parser,
1093  OperationState &result) {
1096  Type type;
1097  auto loc = parser.getCurrentLocation();
1098  SmallVector<Type, 4> indicesTypes;
1099 
1100  if (parser.parseOperand(ptrInfo) ||
1101  parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
1102  parser.parseColonType(type) ||
1103  parser.resolveOperand(ptrInfo, type, result.operands)) {
1104  return failure();
1105  }
1106 
1107  // Check that the provided indices list is not empty before parsing their
1108  // type list.
1109  if (indicesInfo.empty()) {
1110  return mlir::emitError(result.location,
1111  "'spirv.AccessChain' op expected at "
1112  "least one index ");
1113  }
1114 
1115  if (parser.parseComma() || parser.parseTypeList(indicesTypes))
1116  return failure();
1117 
1118  // Check that the indices types list is not empty and that it has a one-to-one
1119  // mapping to the provided indices.
1120  if (indicesTypes.size() != indicesInfo.size()) {
1121  return mlir::emitError(
1122  result.location, "'spirv.AccessChain' op indices types' count must be "
1123  "equal to indices info count");
1124  }
1125 
1126  if (parser.resolveOperands(indicesInfo, indicesTypes, loc, result.operands))
1127  return failure();
1128 
1129  auto resultType = getElementPtrType(
1130  type, llvm::makeArrayRef(result.operands).drop_front(), result.location);
1131  if (!resultType) {
1132  return failure();
1133  }
1134 
1135  result.addTypes(resultType);
1136  return success();
1137 }
1138 
1139 template <typename Op>
1140 static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer) {
1141  printer << ' ' << op.getBasePtr() << '[' << indices
1142  << "] : " << op.getBasePtr().getType() << ", " << indices.getTypes();
1143 }
1144 
1146  printAccessChain(*this, getIndices(), printer);
1147 }
1148 
1149 template <typename Op>
1150 static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) {
1151  auto resultType = getElementPtrType(accessChainOp.getBasePtr().getType(),
1152  indices, accessChainOp.getLoc());
1153  if (!resultType)
1154  return failure();
1155 
1156  auto providedResultType =
1157  accessChainOp.getType().template dyn_cast<spirv::PointerType>();
1158  if (!providedResultType)
1159  return accessChainOp.emitOpError(
1160  "result type must be a pointer, but provided")
1161  << providedResultType;
1162 
1163  if (resultType != providedResultType)
1164  return accessChainOp.emitOpError("invalid result type: expected ")
1165  << resultType << ", but provided " << providedResultType;
1166 
1167  return success();
1168 }
1169 
1171  return verifyAccessChain(*this, getIndices());
1172 }
1173 
1174 //===----------------------------------------------------------------------===//
1175 // spirv.mlir.addressof
1176 //===----------------------------------------------------------------------===//
1177 
1178 void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state,
1179  spirv::GlobalVariableOp var) {
1180  build(builder, state, var.getType(), SymbolRefAttr::get(var));
1181 }
1182 
1184  auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
1185  SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(),
1186  getVariableAttr()));
1187  if (!varOp) {
1188  return emitOpError("expected spirv.GlobalVariable symbol");
1189  }
1190  if (getPointer().getType() != varOp.getType()) {
1191  return emitOpError(
1192  "result type mismatch with the referenced global variable's type");
1193  }
1194  return success();
1195 }
1196 
1197 template <typename T>
1198 static void printAtomicCompareExchangeImpl(T atomOp, OpAsmPrinter &printer) {
1199  printer << " \"" << stringifyScope(atomOp.getMemoryScope()) << "\" \""
1200  << stringifyMemorySemantics(atomOp.getEqualSemantics()) << "\" \""
1201  << stringifyMemorySemantics(atomOp.getUnequalSemantics()) << "\" "
1202  << atomOp.getOperands() << " : " << atomOp.getPointer().getType();
1203 }
1204 
1206  OperationState &state) {
1207  spirv::Scope memoryScope;
1208  spirv::MemorySemantics equalSemantics, unequalSemantics;
1210  Type type;
1211  if (parseEnumStrAttr<spirv::ScopeAttr>(memoryScope, parser, state,
1213  parseEnumStrAttr<spirv::MemorySemanticsAttr>(
1214  equalSemantics, parser, state, kEqualSemanticsAttrName) ||
1215  parseEnumStrAttr<spirv::MemorySemanticsAttr>(
1216  unequalSemantics, parser, state, kUnequalSemanticsAttrName) ||
1217  parser.parseOperandList(operandInfo, 3))
1218  return failure();
1219 
1220  auto loc = parser.getCurrentLocation();
1221  if (parser.parseColonType(type))
1222  return failure();
1223 
1224  auto ptrType = type.dyn_cast<spirv::PointerType>();
1225  if (!ptrType)
1226  return parser.emitError(loc, "expected pointer type");
1227 
1228  if (parser.resolveOperands(
1229  operandInfo,
1230  {ptrType, ptrType.getPointeeType(), ptrType.getPointeeType()},
1231  parser.getNameLoc(), state.operands))
1232  return failure();
1233 
1234  return parser.addTypeToList(ptrType.getPointeeType(), state.types);
1235 }
1236 
1237 template <typename T>
1239  // According to the spec:
1240  // "The type of Value must be the same as Result Type. The type of the value
1241  // pointed to by Pointer must be the same as Result Type. This type must also
1242  // match the type of Comparator."
1243  if (atomOp.getType() != atomOp.getValue().getType())
1244  return atomOp.emitOpError("value operand must have the same type as the op "
1245  "result, but found ")
1246  << atomOp.getValue().getType() << " vs " << atomOp.getType();
1247 
1248  if (atomOp.getType() != atomOp.getComparator().getType())
1249  return atomOp.emitOpError(
1250  "comparator operand must have the same type as the op "
1251  "result, but found ")
1252  << atomOp.getComparator().getType() << " vs " << atomOp.getType();
1253 
1254  Type pointeeType = atomOp.getPointer()
1255  .getType()
1256  .template cast<spirv::PointerType>()
1257  .getPointeeType();
1258  if (atomOp.getType() != pointeeType)
1259  return atomOp.emitOpError(
1260  "pointer operand's pointee type must have the same "
1261  "as the op result type, but found ")
1262  << pointeeType << " vs " << atomOp.getType();
1263 
1264  // TODO: Unequal cannot be set to Release or Acquire and Release.
1265  // In addition, Unequal cannot be set to a stronger memory-order then Equal.
1266 
1267  return success();
1268 }
1269 
1270 //===----------------------------------------------------------------------===//
1271 // spirv.AtomicAndOp
1272 //===----------------------------------------------------------------------===//
1273 
1275  return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1276 }
1277 
1278 ParseResult spirv::AtomicAndOp::parse(OpAsmParser &parser,
1279  OperationState &result) {
1280  return ::parseAtomicUpdateOp(parser, result, true);
1281 }
1283  ::printAtomicUpdateOp(*this, p);
1284 }
1285 
1286 //===----------------------------------------------------------------------===//
1287 // spirv.AtomicCompareExchangeOp
1288 //===----------------------------------------------------------------------===//
1289 
1292 }
1293 
1294 ParseResult spirv::AtomicCompareExchangeOp::parse(OpAsmParser &parser,
1295  OperationState &result) {
1297 }
1300 }
1301 
1302 //===----------------------------------------------------------------------===//
1303 // spirv.AtomicCompareExchangeWeakOp
1304 //===----------------------------------------------------------------------===//
1305 
1308 }
1309 
1310 ParseResult spirv::AtomicCompareExchangeWeakOp::parse(OpAsmParser &parser,
1311  OperationState &result) {
1313 }
1316 }
1317 
1318 //===----------------------------------------------------------------------===//
1319 // spirv.AtomicExchange
1320 //===----------------------------------------------------------------------===//
1321 
1323  printer << " \"" << stringifyScope(getMemoryScope()) << "\" \""
1324  << stringifyMemorySemantics(getSemantics()) << "\" " << getOperands()
1325  << " : " << getPointer().getType();
1326 }
1327 
1328 ParseResult spirv::AtomicExchangeOp::parse(OpAsmParser &parser,
1329  OperationState &result) {
1330  spirv::Scope memoryScope;
1331  spirv::MemorySemantics semantics;
1333  Type type;
1334  if (parseEnumStrAttr<spirv::ScopeAttr>(memoryScope, parser, result,
1336  parseEnumStrAttr<spirv::MemorySemanticsAttr>(semantics, parser, result,
1337  kSemanticsAttrName) ||
1338  parser.parseOperandList(operandInfo, 2))
1339  return failure();
1340 
1341  auto loc = parser.getCurrentLocation();
1342  if (parser.parseColonType(type))
1343  return failure();
1344 
1345  auto ptrType = type.dyn_cast<spirv::PointerType>();
1346  if (!ptrType)
1347  return parser.emitError(loc, "expected pointer type");
1348 
1349  if (parser.resolveOperands(operandInfo, {ptrType, ptrType.getPointeeType()},
1350  parser.getNameLoc(), result.operands))
1351  return failure();
1352 
1353  return parser.addTypeToList(ptrType.getPointeeType(), result.types);
1354 }
1355 
1357  if (getType() != getValue().getType())
1358  return emitOpError("value operand must have the same type as the op "
1359  "result, but found ")
1360  << getValue().getType() << " vs " << getType();
1361 
1362  Type pointeeType =
1363  getPointer().getType().cast<spirv::PointerType>().getPointeeType();
1364  if (getType() != pointeeType)
1365  return emitOpError("pointer operand's pointee type must have the same "
1366  "as the op result type, but found ")
1367  << pointeeType << " vs " << getType();
1368 
1369  return success();
1370 }
1371 
1372 //===----------------------------------------------------------------------===//
1373 // spirv.AtomicIAddOp
1374 //===----------------------------------------------------------------------===//
1375 
1377  return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1378 }
1379 
1380 ParseResult spirv::AtomicIAddOp::parse(OpAsmParser &parser,
1381  OperationState &result) {
1382  return ::parseAtomicUpdateOp(parser, result, true);
1383 }
1385  ::printAtomicUpdateOp(*this, p);
1386 }
1387 
1388 //===----------------------------------------------------------------------===//
1389 // spirv.EXT.AtomicFAddOp
1390 //===----------------------------------------------------------------------===//
1391 
1393  return ::verifyAtomicUpdateOp<FloatType>(getOperation());
1394 }
1395 
1396 ParseResult spirv::EXTAtomicFAddOp::parse(OpAsmParser &parser,
1397  OperationState &result) {
1398  return ::parseAtomicUpdateOp(parser, result, true);
1399 }
1401  ::printAtomicUpdateOp(*this, p);
1402 }
1403 
1404 //===----------------------------------------------------------------------===//
1405 // spirv.AtomicIDecrementOp
1406 //===----------------------------------------------------------------------===//
1407 
1409  return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1410 }
1411 
1412 ParseResult spirv::AtomicIDecrementOp::parse(OpAsmParser &parser,
1413  OperationState &result) {
1414  return ::parseAtomicUpdateOp(parser, result, false);
1415 }
1417  ::printAtomicUpdateOp(*this, p);
1418 }
1419 
1420 //===----------------------------------------------------------------------===//
1421 // spirv.AtomicIIncrementOp
1422 //===----------------------------------------------------------------------===//
1423 
1425  return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1426 }
1427 
1428 ParseResult spirv::AtomicIIncrementOp::parse(OpAsmParser &parser,
1429  OperationState &result) {
1430  return ::parseAtomicUpdateOp(parser, result, false);
1431 }
1433  ::printAtomicUpdateOp(*this, p);
1434 }
1435 
1436 //===----------------------------------------------------------------------===//
1437 // spirv.AtomicISubOp
1438 //===----------------------------------------------------------------------===//
1439 
1441  return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1442 }
1443 
1444 ParseResult spirv::AtomicISubOp::parse(OpAsmParser &parser,
1445  OperationState &result) {
1446  return ::parseAtomicUpdateOp(parser, result, true);
1447 }
1449  ::printAtomicUpdateOp(*this, p);
1450 }
1451 
1452 //===----------------------------------------------------------------------===//
1453 // spirv.AtomicOrOp
1454 //===----------------------------------------------------------------------===//
1455 
1457  return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1458 }
1459 
1460 ParseResult spirv::AtomicOrOp::parse(OpAsmParser &parser,
1461  OperationState &result) {
1462  return ::parseAtomicUpdateOp(parser, result, true);
1463 }
1465  ::printAtomicUpdateOp(*this, p);
1466 }
1467 
1468 //===----------------------------------------------------------------------===//
1469 // spirv.AtomicSMaxOp
1470 //===----------------------------------------------------------------------===//
1471 
1473  return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1474 }
1475 
1476 ParseResult spirv::AtomicSMaxOp::parse(OpAsmParser &parser,
1477  OperationState &result) {
1478  return ::parseAtomicUpdateOp(parser, result, true);
1479 }
1481  ::printAtomicUpdateOp(*this, p);
1482 }
1483 
1484 //===----------------------------------------------------------------------===//
1485 // spirv.AtomicSMinOp
1486 //===----------------------------------------------------------------------===//
1487 
1489  return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1490 }
1491 
1492 ParseResult spirv::AtomicSMinOp::parse(OpAsmParser &parser,
1493  OperationState &result) {
1494  return ::parseAtomicUpdateOp(parser, result, true);
1495 }
1497  ::printAtomicUpdateOp(*this, p);
1498 }
1499 
1500 //===----------------------------------------------------------------------===//
1501 // spirv.AtomicUMaxOp
1502 //===----------------------------------------------------------------------===//
1503 
1505  return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1506 }
1507 
1508 ParseResult spirv::AtomicUMaxOp::parse(OpAsmParser &parser,
1509  OperationState &result) {
1510  return ::parseAtomicUpdateOp(parser, result, true);
1511 }
1513  ::printAtomicUpdateOp(*this, p);
1514 }
1515 
1516 //===----------------------------------------------------------------------===//
1517 // spirv.AtomicUMinOp
1518 //===----------------------------------------------------------------------===//
1519 
1521  return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1522 }
1523 
1524 ParseResult spirv::AtomicUMinOp::parse(OpAsmParser &parser,
1525  OperationState &result) {
1526  return ::parseAtomicUpdateOp(parser, result, true);
1527 }
1529  ::printAtomicUpdateOp(*this, p);
1530 }
1531 
1532 //===----------------------------------------------------------------------===//
1533 // spirv.AtomicXorOp
1534 //===----------------------------------------------------------------------===//
1535 
1537  return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1538 }
1539 
1540 ParseResult spirv::AtomicXorOp::parse(OpAsmParser &parser,
1541  OperationState &result) {
1542  return ::parseAtomicUpdateOp(parser, result, true);
1543 }
1545  ::printAtomicUpdateOp(*this, p);
1546 }
1547 
1548 //===----------------------------------------------------------------------===//
1549 // spirv.BitcastOp
1550 //===----------------------------------------------------------------------===//
1551 
1553  // TODO: The SPIR-V spec validation rules are different for different
1554  // versions.
1555  auto operandType = getOperand().getType();
1556  auto resultType = getResult().getType();
1557  if (operandType == resultType) {
1558  return emitError("result type must be different from operand type");
1559  }
1560  if (operandType.isa<spirv::PointerType>() &&
1561  !resultType.isa<spirv::PointerType>()) {
1562  return emitError(
1563  "unhandled bit cast conversion from pointer type to non-pointer type");
1564  }
1565  if (!operandType.isa<spirv::PointerType>() &&
1566  resultType.isa<spirv::PointerType>()) {
1567  return emitError(
1568  "unhandled bit cast conversion from non-pointer type to pointer type");
1569  }
1570  auto operandBitWidth = getBitWidth(operandType);
1571  auto resultBitWidth = getBitWidth(resultType);
1572  if (operandBitWidth != resultBitWidth) {
1573  return emitOpError("mismatch in result type bitwidth ")
1574  << resultBitWidth << " and operand type bitwidth "
1575  << operandBitWidth;
1576  }
1577  return success();
1578 }
1579 
1580 //===----------------------------------------------------------------------===//
1581 // spirv.PtrCastToGenericOp
1582 //===----------------------------------------------------------------------===//
1583 
1585  auto operandType = getPointer().getType().cast<spirv::PointerType>();
1586  auto resultType = getResult().getType().cast<spirv::PointerType>();
1587 
1588  spirv::StorageClass operandStorage = operandType.getStorageClass();
1589  if (operandStorage != spirv::StorageClass::Workgroup &&
1590  operandStorage != spirv::StorageClass::CrossWorkgroup &&
1591  operandStorage != spirv::StorageClass::Function)
1592  return emitError("pointer must point to the Workgroup, CrossWorkgroup"
1593  ", or Function Storage Class");
1594 
1595  spirv::StorageClass resultStorage = resultType.getStorageClass();
1596  if (resultStorage != spirv::StorageClass::Generic)
1597  return emitError("result type must be of storage class Generic");
1598 
1599  Type operandPointeeType = operandType.getPointeeType();
1600  Type resultPointeeType = resultType.getPointeeType();
1601  if (operandPointeeType != resultPointeeType)
1602  return emitOpError("pointer operand's pointee type must have the same "
1603  "as the op result type, but found ")
1604  << operandPointeeType << " vs " << resultPointeeType;
1605  return success();
1606 }
1607 
1608 //===----------------------------------------------------------------------===//
1609 // spirv.GenericCastToPtrOp
1610 //===----------------------------------------------------------------------===//
1611 
1613  auto operandType = getPointer().getType().cast<spirv::PointerType>();
1614  auto resultType = getResult().getType().cast<spirv::PointerType>();
1615 
1616  spirv::StorageClass operandStorage = operandType.getStorageClass();
1617  if (operandStorage != spirv::StorageClass::Generic)
1618  return emitError("pointer type must be of storage class Generic");
1619 
1620  spirv::StorageClass resultStorage = resultType.getStorageClass();
1621  if (resultStorage != spirv::StorageClass::Workgroup &&
1622  resultStorage != spirv::StorageClass::CrossWorkgroup &&
1623  resultStorage != spirv::StorageClass::Function)
1624  return emitError("result must point to the Workgroup, CrossWorkgroup, "
1625  "or Function Storage Class");
1626 
1627  Type operandPointeeType = operandType.getPointeeType();
1628  Type resultPointeeType = resultType.getPointeeType();
1629  if (operandPointeeType != resultPointeeType)
1630  return emitOpError("pointer operand's pointee type must have the same "
1631  "as the op result type, but found ")
1632  << operandPointeeType << " vs " << resultPointeeType;
1633  return success();
1634 }
1635 
1636 //===----------------------------------------------------------------------===//
1637 // spirv.GenericCastToPtrExplicitOp
1638 //===----------------------------------------------------------------------===//
1639 
1641  auto operandType = getPointer().getType().cast<spirv::PointerType>();
1642  auto resultType = getResult().getType().cast<spirv::PointerType>();
1643 
1644  spirv::StorageClass operandStorage = operandType.getStorageClass();
1645  if (operandStorage != spirv::StorageClass::Generic)
1646  return emitError("pointer type must be of storage class Generic");
1647 
1648  spirv::StorageClass resultStorage = resultType.getStorageClass();
1649  if (resultStorage != spirv::StorageClass::Workgroup &&
1650  resultStorage != spirv::StorageClass::CrossWorkgroup &&
1651  resultStorage != spirv::StorageClass::Function)
1652  return emitError("result must point to the Workgroup, CrossWorkgroup, "
1653  "or Function Storage Class");
1654 
1655  Type operandPointeeType = operandType.getPointeeType();
1656  Type resultPointeeType = resultType.getPointeeType();
1657  if (operandPointeeType != resultPointeeType)
1658  return emitOpError("pointer operand's pointee type must have the same "
1659  "as the op result type, but found ")
1660  << operandPointeeType << " vs " << resultPointeeType;
1661  return success();
1662 }
1663 
1664 //===----------------------------------------------------------------------===//
1665 // spirv.BranchOp
1666 //===----------------------------------------------------------------------===//
1667 
1668 SuccessorOperands spirv::BranchOp::getSuccessorOperands(unsigned index) {
1669  assert(index == 0 && "invalid successor index");
1670  return SuccessorOperands(0, getTargetOperandsMutable());
1671 }
1672 
1673 //===----------------------------------------------------------------------===//
1674 // spirv.BranchConditionalOp
1675 //===----------------------------------------------------------------------===//
1676 
1678 spirv::BranchConditionalOp::getSuccessorOperands(unsigned index) {
1679  assert(index < 2 && "invalid successor index");
1680  return SuccessorOperands(index == kTrueIndex
1681  ? getTrueTargetOperandsMutable()
1682  : getFalseTargetOperandsMutable());
1683 }
1684 
1685 ParseResult spirv::BranchConditionalOp::parse(OpAsmParser &parser,
1686  OperationState &result) {
1687  auto &builder = parser.getBuilder();
1689  Block *dest;
1690 
1691  // Parse the condition.
1692  Type boolTy = builder.getI1Type();
1693  if (parser.parseOperand(condInfo) ||
1694  parser.resolveOperand(condInfo, boolTy, result.operands))
1695  return failure();
1696 
1697  // Parse the optional branch weights.
1698  if (succeeded(parser.parseOptionalLSquare())) {
1699  IntegerAttr trueWeight, falseWeight;
1700  NamedAttrList weights;
1701 
1702  auto i32Type = builder.getIntegerType(32);
1703  if (parser.parseAttribute(trueWeight, i32Type, "weight", weights) ||
1704  parser.parseComma() ||
1705  parser.parseAttribute(falseWeight, i32Type, "weight", weights) ||
1706  parser.parseRSquare())
1707  return failure();
1708 
1710  builder.getArrayAttr({trueWeight, falseWeight}));
1711  }
1712 
1713  // Parse the true branch.
1714  SmallVector<Value, 4> trueOperands;
1715  if (parser.parseComma() ||
1716  parser.parseSuccessorAndUseList(dest, trueOperands))
1717  return failure();
1718  result.addSuccessors(dest);
1719  result.addOperands(trueOperands);
1720 
1721  // Parse the false branch.
1722  SmallVector<Value, 4> falseOperands;
1723  if (parser.parseComma() ||
1724  parser.parseSuccessorAndUseList(dest, falseOperands))
1725  return failure();
1726  result.addSuccessors(dest);
1727  result.addOperands(falseOperands);
1728  result.addAttribute(spirv::BranchConditionalOp::getOperandSegmentSizeAttr(),
1729  builder.getDenseI32ArrayAttr(
1730  {1, static_cast<int32_t>(trueOperands.size()),
1731  static_cast<int32_t>(falseOperands.size())}));
1732 
1733  return success();
1734 }
1735 
1737  printer << ' ' << getCondition();
1738 
1739  if (auto weights = getBranchWeights()) {
1740  printer << " [";
1741  llvm::interleaveComma(weights->getValue(), printer, [&](Attribute a) {
1742  printer << a.cast<IntegerAttr>().getInt();
1743  });
1744  printer << "]";
1745  }
1746 
1747  printer << ", ";
1748  printer.printSuccessorAndUseList(getTrueBlock(), getTrueBlockArguments());
1749  printer << ", ";
1750  printer.printSuccessorAndUseList(getFalseBlock(), getFalseBlockArguments());
1751 }
1752 
1754  if (auto weights = getBranchWeights()) {
1755  if (weights->getValue().size() != 2) {
1756  return emitOpError("must have exactly two branch weights");
1757  }
1758  if (llvm::all_of(*weights, [](Attribute attr) {
1759  return attr.cast<IntegerAttr>().getValue().isNullValue();
1760  }))
1761  return emitOpError("branch weights cannot both be zero");
1762  }
1763 
1764  return success();
1765 }
1766 
1767 //===----------------------------------------------------------------------===//
1768 // spirv.CompositeConstruct
1769 //===----------------------------------------------------------------------===//
1770 
1772  auto cType = getType().cast<spirv::CompositeType>();
1773  operand_range constituents = this->getConstituents();
1774 
1775  if (auto coopType = cType.dyn_cast<spirv::CooperativeMatrixNVType>()) {
1776  if (constituents.size() != 1)
1777  return emitOpError("has incorrect number of operands: expected ")
1778  << "1, but provided " << constituents.size();
1779  if (coopType.getElementType() != constituents.front().getType())
1780  return emitOpError("operand type mismatch: expected operand type ")
1781  << coopType.getElementType() << ", but provided "
1782  << constituents.front().getType();
1783  return success();
1784  }
1785 
1786  if (auto jointType = cType.dyn_cast<spirv::JointMatrixINTELType>()) {
1787  if (constituents.size() != 1)
1788  return emitOpError("has incorrect number of operands: expected ")
1789  << "1, but provided " << constituents.size();
1790  if (jointType.getElementType() != constituents.front().getType())
1791  return emitOpError("operand type mismatch: expected operand type ")
1792  << jointType.getElementType() << ", but provided "
1793  << constituents.front().getType();
1794  return success();
1795  }
1796 
1797  if (constituents.size() == cType.getNumElements()) {
1798  for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
1799  if (constituents[index].getType() != cType.getElementType(index)) {
1800  return emitOpError("operand type mismatch: expected operand type ")
1801  << cType.getElementType(index) << ", but provided "
1802  << constituents[index].getType();
1803  }
1804  }
1805  return success();
1806  }
1807 
1808  // If not constructing a cooperative matrix type, then we must be constructing
1809  // a vector type.
1810  auto resultType = cType.dyn_cast<VectorType>();
1811  if (!resultType)
1812  return emitOpError(
1813  "expected to return a vector or cooperative matrix when the number of "
1814  "constituents is less than what the result needs");
1815 
1816  SmallVector<unsigned> sizes;
1817  for (Value component : constituents) {
1818  if (!component.getType().isa<VectorType>() &&
1819  !component.getType().isIntOrFloat())
1820  return emitOpError("operand type mismatch: expected operand to have "
1821  "a scalar or vector type, but provided ")
1822  << component.getType();
1823 
1824  Type elementType = component.getType();
1825  if (auto vectorType = component.getType().dyn_cast<VectorType>()) {
1826  sizes.push_back(vectorType.getNumElements());
1827  elementType = vectorType.getElementType();
1828  } else {
1829  sizes.push_back(1);
1830  }
1831 
1832  if (elementType != resultType.getElementType())
1833  return emitOpError("operand element type mismatch: expected to be ")
1834  << resultType.getElementType() << ", but provided " << elementType;
1835  }
1836  unsigned totalCount = std::accumulate(sizes.begin(), sizes.end(), 0);
1837  if (totalCount != cType.getNumElements())
1838  return emitOpError("has incorrect number of operands: expected ")
1839  << cType.getNumElements() << ", but provided " << totalCount;
1840  return success();
1841 }
1842 
1843 //===----------------------------------------------------------------------===//
1844 // spirv.CompositeExtractOp
1845 //===----------------------------------------------------------------------===//
1846 
1847 void spirv::CompositeExtractOp::build(OpBuilder &builder, OperationState &state,
1848  Value composite,
1849  ArrayRef<int32_t> indices) {
1850  auto indexAttr = builder.getI32ArrayAttr(indices);
1851  auto elementType =
1852  getElementType(composite.getType(), indexAttr, state.location);
1853  if (!elementType) {
1854  return;
1855  }
1856  build(builder, state, elementType, composite, indexAttr);
1857 }
1858 
1859 ParseResult spirv::CompositeExtractOp::parse(OpAsmParser &parser,
1860  OperationState &result) {
1861  OpAsmParser::UnresolvedOperand compositeInfo;
1862  Attribute indicesAttr;
1863  Type compositeType;
1864  SMLoc attrLocation;
1865 
1866  if (parser.parseOperand(compositeInfo) ||
1867  parser.getCurrentLocation(&attrLocation) ||
1868  parser.parseAttribute(indicesAttr, kIndicesAttrName, result.attributes) ||
1869  parser.parseColonType(compositeType) ||
1870  parser.resolveOperand(compositeInfo, compositeType, result.operands)) {
1871  return failure();
1872  }
1873 
1874  Type resultType =
1875  getElementType(compositeType, indicesAttr, parser, attrLocation);
1876  if (!resultType) {
1877  return failure();
1878  }
1879  result.addTypes(resultType);
1880  return success();
1881 }
1882 
1884  printer << ' ' << getComposite() << getIndices() << " : "
1885  << getComposite().getType();
1886 }
1887 
1889  auto indicesArrayAttr = getIndices().dyn_cast<ArrayAttr>();
1890  auto resultType =
1891  getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
1892  if (!resultType)
1893  return failure();
1894 
1895  if (resultType != getType()) {
1896  return emitOpError("invalid result type: expected ")
1897  << resultType << " but provided " << getType();
1898  }
1899 
1900  return success();
1901 }
1902 
1903 //===----------------------------------------------------------------------===//
1904 // spirv.CompositeInsert
1905 //===----------------------------------------------------------------------===//
1906 
1907 void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state,
1908  Value object, Value composite,
1909  ArrayRef<int32_t> indices) {
1910  auto indexAttr = builder.getI32ArrayAttr(indices);
1911  build(builder, state, composite.getType(), object, composite, indexAttr);
1912 }
1913 
1914 ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
1915  OperationState &result) {
1917  Type objectType, compositeType;
1918  Attribute indicesAttr;
1919  auto loc = parser.getCurrentLocation();
1920 
1921  return failure(
1922  parser.parseOperandList(operands, 2) ||
1923  parser.parseAttribute(indicesAttr, kIndicesAttrName, result.attributes) ||
1924  parser.parseColonType(objectType) ||
1925  parser.parseKeywordType("into", compositeType) ||
1926  parser.resolveOperands(operands, {objectType, compositeType}, loc,
1927  result.operands) ||
1928  parser.addTypesToList(compositeType, result.types));
1929 }
1930 
1932  auto indicesArrayAttr = getIndices().dyn_cast<ArrayAttr>();
1933  auto objectType =
1934  getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
1935  if (!objectType)
1936  return failure();
1937 
1938  if (objectType != getObject().getType()) {
1939  return emitOpError("object operand type should be ")
1940  << objectType << ", but found " << getObject().getType();
1941  }
1942 
1943  if (getComposite().getType() != getType()) {
1944  return emitOpError("result type should be the same as "
1945  "the composite type, but found ")
1946  << getComposite().getType() << " vs " << getType();
1947  }
1948 
1949  return success();
1950 }
1951 
1953  printer << " " << getObject() << ", " << getComposite() << getIndices()
1954  << " : " << getObject().getType() << " into "
1955  << getComposite().getType();
1956 }
1957 
1958 //===----------------------------------------------------------------------===//
1959 // spirv.Constant
1960 //===----------------------------------------------------------------------===//
1961 
1962 ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
1963  OperationState &result) {
1964  Attribute value;
1965  if (parser.parseAttribute(value, kValueAttrName, result.attributes))
1966  return failure();
1967 
1968  Type type = NoneType::get(parser.getContext());
1969  if (auto typedAttr = value.dyn_cast<TypedAttr>())
1970  type = typedAttr.getType();
1971  if (type.isa<NoneType, TensorType>()) {
1972  if (parser.parseColonType(type))
1973  return failure();
1974  }
1975 
1976  return parser.addTypeToList(type, result.types);
1977 }
1978 
1979 void spirv::ConstantOp::print(OpAsmPrinter &printer) {
1980  printer << ' ' << getValue();
1981  if (getType().isa<spirv::ArrayType>())
1982  printer << " : " << getType();
1983 }
1984 
1985 static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
1986  Type opType) {
1987  if (value.isa<IntegerAttr, FloatAttr>()) {
1988  auto valueType = value.cast<TypedAttr>().getType();
1989  if (valueType != opType)
1990  return op.emitOpError("result type (")
1991  << opType << ") does not match value type (" << valueType << ")";
1992  return success();
1993  }
1994  if (value.isa<DenseIntOrFPElementsAttr, SparseElementsAttr>()) {
1995  auto valueType = value.cast<TypedAttr>().getType();
1996  if (valueType == opType)
1997  return success();
1998  auto arrayType = opType.dyn_cast<spirv::ArrayType>();
1999  auto shapedType = valueType.dyn_cast<ShapedType>();
2000  if (!arrayType)
2001  return op.emitOpError("result or element type (")
2002  << opType << ") does not match value type (" << valueType
2003  << "), must be the same or spirv.array";
2004 
2005  int numElements = arrayType.getNumElements();
2006  auto opElemType = arrayType.getElementType();
2007  while (auto t = opElemType.dyn_cast<spirv::ArrayType>()) {
2008  numElements *= t.getNumElements();
2009  opElemType = t.getElementType();
2010  }
2011  if (!opElemType.isIntOrFloat())
2012  return op.emitOpError("only support nested array result type");
2013 
2014  auto valueElemType = shapedType.getElementType();
2015  if (valueElemType != opElemType) {
2016  return op.emitOpError("result element type (")
2017  << opElemType << ") does not match value element type ("
2018  << valueElemType << ")";
2019  }
2020 
2021  if (numElements != shapedType.getNumElements()) {
2022  return op.emitOpError("result number of elements (")
2023  << numElements << ") does not match value number of elements ("
2024  << shapedType.getNumElements() << ")";
2025  }
2026  return success();
2027  }
2028  if (auto arrayAttr = value.dyn_cast<ArrayAttr>()) {
2029  auto arrayType = opType.dyn_cast<spirv::ArrayType>();
2030  if (!arrayType)
2031  return op.emitOpError(
2032  "must have spirv.array result type for array value");
2033  Type elemType = arrayType.getElementType();
2034  for (Attribute element : arrayAttr.getValue()) {
2035  // Verify array elements recursively.
2036  if (failed(verifyConstantType(op, element, elemType)))
2037  return failure();
2038  }
2039  return success();
2040  }
2041  return op.emitOpError("cannot have attribute: ") << value;
2042 }
2043 
2045  // ODS already generates checks to make sure the result type is valid. We just
2046  // need to additionally check that the value's attribute type is consistent
2047  // with the result type.
2048  return verifyConstantType(*this, getValueAttr(), getType());
2049 }
2050 
2051 bool spirv::ConstantOp::isBuildableWith(Type type) {
2052  // Must be valid SPIR-V type first.
2053  if (!type.isa<spirv::SPIRVType>())
2054  return false;
2055 
2056  if (isa<SPIRVDialect>(type.getDialect())) {
2057  // TODO: support constant struct
2058  return type.isa<spirv::ArrayType>();
2059  }
2060 
2061  return true;
2062 }
2063 
2064 spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
2065  OpBuilder &builder) {
2066  if (auto intType = type.dyn_cast<IntegerType>()) {
2067  unsigned width = intType.getWidth();
2068  if (width == 1)
2069  return builder.create<spirv::ConstantOp>(loc, type,
2070  builder.getBoolAttr(false));
2071  return builder.create<spirv::ConstantOp>(
2072  loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
2073  }
2074  if (auto floatType = type.dyn_cast<FloatType>()) {
2075  return builder.create<spirv::ConstantOp>(
2076  loc, type, builder.getFloatAttr(floatType, 0.0));
2077  }
2078  if (auto vectorType = type.dyn_cast<VectorType>()) {
2079  Type elemType = vectorType.getElementType();
2080  if (elemType.isa<IntegerType>()) {
2081  return builder.create<spirv::ConstantOp>(
2082  loc, type,
2083  DenseElementsAttr::get(vectorType,
2084  IntegerAttr::get(elemType, 0).getValue()));
2085  }
2086  if (elemType.isa<FloatType>()) {
2087  return builder.create<spirv::ConstantOp>(
2088  loc, type,
2089  DenseFPElementsAttr::get(vectorType,
2090  FloatAttr::get(elemType, 0.0).getValue()));
2091  }
2092  }
2093 
2094  llvm_unreachable("unimplemented types for ConstantOp::getZero()");
2095 }
2096 
2097 spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
2098  OpBuilder &builder) {
2099  if (auto intType = type.dyn_cast<IntegerType>()) {
2100  unsigned width = intType.getWidth();
2101  if (width == 1)
2102  return builder.create<spirv::ConstantOp>(loc, type,
2103  builder.getBoolAttr(true));
2104  return builder.create<spirv::ConstantOp>(
2105  loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
2106  }
2107  if (auto floatType = type.dyn_cast<FloatType>()) {
2108  return builder.create<spirv::ConstantOp>(
2109  loc, type, builder.getFloatAttr(floatType, 1.0));
2110  }
2111  if (auto vectorType = type.dyn_cast<VectorType>()) {
2112  Type elemType = vectorType.getElementType();
2113  if (elemType.isa<IntegerType>()) {
2114  return builder.create<spirv::ConstantOp>(
2115  loc, type,
2116  DenseElementsAttr::get(vectorType,
2117  IntegerAttr::get(elemType, 1).getValue()));
2118  }
2119  if (elemType.isa<FloatType>()) {
2120  return builder.create<spirv::ConstantOp>(
2121  loc, type,
2122  DenseFPElementsAttr::get(vectorType,
2123  FloatAttr::get(elemType, 1.0).getValue()));
2124  }
2125  }
2126 
2127  llvm_unreachable("unimplemented types for ConstantOp::getOne()");
2128 }
2129 
2130 void mlir::spirv::ConstantOp::getAsmResultNames(
2131  llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
2132  Type type = getType();
2133 
2134  SmallString<32> specialNameBuffer;
2135  llvm::raw_svector_ostream specialName(specialNameBuffer);
2136  specialName << "cst";
2137 
2138  IntegerType intTy = type.dyn_cast<IntegerType>();
2139 
2140  if (IntegerAttr intCst = getValue().dyn_cast<IntegerAttr>()) {
2141  if (intTy && intTy.getWidth() == 1) {
2142  return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
2143  }
2144 
2145  if (intTy.isSignless()) {
2146  specialName << intCst.getInt();
2147  } else {
2148  specialName << intCst.getSInt();
2149  }
2150  }
2151 
2152  if (intTy || type.isa<FloatType>()) {
2153  specialName << '_' << type;
2154  }
2155 
2156  if (auto vecType = type.dyn_cast<VectorType>()) {
2157  specialName << "_vec_";
2158  specialName << vecType.getDimSize(0);
2159 
2160  Type elementType = vecType.getElementType();
2161 
2162  if (elementType.isa<IntegerType>() || elementType.isa<FloatType>()) {
2163  specialName << "x" << elementType;
2164  }
2165  }
2166 
2167  setNameFn(getResult(), specialName.str());
2168 }
2169 
2170 void mlir::spirv::AddressOfOp::getAsmResultNames(
2171  llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
2172  SmallString<32> specialNameBuffer;
2173  llvm::raw_svector_ostream specialName(specialNameBuffer);
2174  specialName << getVariable() << "_addr";
2175  setNameFn(getResult(), specialName.str());
2176 }
2177 
2178 //===----------------------------------------------------------------------===//
2179 // spirv.ControlBarrierOp
2180 //===----------------------------------------------------------------------===//
2181 
2183  return verifyMemorySemantics(getOperation(), getMemorySemantics());
2184 }
2185 
2186 //===----------------------------------------------------------------------===//
2187 // spirv.ConvertFToSOp
2188 //===----------------------------------------------------------------------===//
2189 
2191  return verifyCastOp(*this, /*requireSameBitWidth=*/false,
2192  /*skipBitWidthCheck=*/true);
2193 }
2194 
2195 //===----------------------------------------------------------------------===//
2196 // spirv.ConvertFToUOp
2197 //===----------------------------------------------------------------------===//
2198 
2200  return verifyCastOp(*this, /*requireSameBitWidth=*/false,
2201  /*skipBitWidthCheck=*/true);
2202 }
2203 
2204 //===----------------------------------------------------------------------===//
2205 // spirv.ConvertSToFOp
2206 //===----------------------------------------------------------------------===//
2207 
2209  return verifyCastOp(*this, /*requireSameBitWidth=*/false,
2210  /*skipBitWidthCheck=*/true);
2211 }
2212 
2213 //===----------------------------------------------------------------------===//
2214 // spirv.ConvertUToFOp
2215 //===----------------------------------------------------------------------===//
2216 
2218  return verifyCastOp(*this, /*requireSameBitWidth=*/false,
2219  /*skipBitWidthCheck=*/true);
2220 }
2221 
2222 //===----------------------------------------------------------------------===//
2223 // spirv.EntryPoint
2224 //===----------------------------------------------------------------------===//
2225 
2226 void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state,
2227  spirv::ExecutionModel executionModel,
2228  spirv::FuncOp function,
2229  ArrayRef<Attribute> interfaceVars) {
2230  build(builder, state,
2231  spirv::ExecutionModelAttr::get(builder.getContext(), executionModel),
2232  SymbolRefAttr::get(function), builder.getArrayAttr(interfaceVars));
2233 }
2234 
2235 ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
2236  OperationState &result) {
2237  spirv::ExecutionModel execModel;
2239  SmallVector<Type, 0> idTypes;
2240  SmallVector<Attribute, 4> interfaceVars;
2241 
2242  FlatSymbolRefAttr fn;
2243  if (parseEnumStrAttr<spirv::ExecutionModelAttr>(execModel, parser, result) ||
2244  parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) {
2245  return failure();
2246  }
2247 
2248  if (!parser.parseOptionalComma()) {
2249  // Parse the interface variables
2250  if (parser.parseCommaSeparatedList([&]() -> ParseResult {
2251  // The name of the interface variable attribute isnt important
2252  FlatSymbolRefAttr var;
2253  NamedAttrList attrs;
2254  if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
2255  return failure();
2256  interfaceVars.push_back(var);
2257  return success();
2258  }))
2259  return failure();
2260  }
2262  parser.getBuilder().getArrayAttr(interfaceVars));
2263  return success();
2264 }
2265 
2267  printer << " \"" << stringifyExecutionModel(getExecutionModel()) << "\" ";
2268  printer.printSymbolName(getFn());
2269  auto interfaceVars = getInterface().getValue();
2270  if (!interfaceVars.empty()) {
2271  printer << ", ";
2272  llvm::interleaveComma(interfaceVars, printer);
2273  }
2274 }
2275 
2277  // Checks for fn and interface symbol reference are done in spirv::ModuleOp
2278  // verification.
2279  return success();
2280 }
2281 
2282 //===----------------------------------------------------------------------===//
2283 // spirv.ExecutionMode
2284 //===----------------------------------------------------------------------===//
2285 
2286 void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state,
2287  spirv::FuncOp function,
2288  spirv::ExecutionMode executionMode,
2289  ArrayRef<int32_t> params) {
2290  build(builder, state, SymbolRefAttr::get(function),
2291  spirv::ExecutionModeAttr::get(builder.getContext(), executionMode),
2292  builder.getI32ArrayAttr(params));
2293 }
2294 
2295 ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
2296  OperationState &result) {
2297  spirv::ExecutionMode execMode;
2298  Attribute fn;
2299  if (parser.parseAttribute(fn, kFnNameAttrName, result.attributes) ||
2300  parseEnumStrAttr<spirv::ExecutionModeAttr>(execMode, parser, result)) {
2301  return failure();
2302  }
2303 
2304  SmallVector<int32_t, 4> values;
2305  Type i32Type = parser.getBuilder().getIntegerType(32);
2306  while (!parser.parseOptionalComma()) {
2307  NamedAttrList attr;
2308  Attribute value;
2309  if (parser.parseAttribute(value, i32Type, "value", attr)) {
2310  return failure();
2311  }
2312  values.push_back(value.cast<IntegerAttr>().getInt());
2313  }
2315  parser.getBuilder().getI32ArrayAttr(values));
2316  return success();
2317 }
2318 
2320  printer << " ";
2321  printer.printSymbolName(getFn());
2322  printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\"";
2323  auto values = this->getValues();
2324  if (values.empty())
2325  return;
2326  printer << ", ";
2327  llvm::interleaveComma(values, printer, [&](Attribute a) {
2328  printer << a.cast<IntegerAttr>().getInt();
2329  });
2330 }
2331 
2332 //===----------------------------------------------------------------------===//
2333 // spirv.FConvertOp
2334 //===----------------------------------------------------------------------===//
2335 
2337  return verifyCastOp(*this, /*requireSameBitWidth=*/false);
2338 }
2339 
2340 //===----------------------------------------------------------------------===//
2341 // spirv.SConvertOp
2342 //===----------------------------------------------------------------------===//
2343 
2345  return verifyCastOp(*this, /*requireSameBitWidth=*/false);
2346 }
2347 
2348 //===----------------------------------------------------------------------===//
2349 // spirv.UConvertOp
2350 //===----------------------------------------------------------------------===//
2351 
2353  return verifyCastOp(*this, /*requireSameBitWidth=*/false);
2354 }
2355 
2356 //===----------------------------------------------------------------------===//
2357 // spirv.func
2358 //===----------------------------------------------------------------------===//
2359 
2360 ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
2362  SmallVector<DictionaryAttr> resultAttrs;
2363  SmallVector<Type> resultTypes;
2364  auto &builder = parser.getBuilder();
2365 
2366  // Parse the name as a symbol.
2367  StringAttr nameAttr;
2368  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
2369  result.attributes))
2370  return failure();
2371 
2372  // Parse the function signature.
2373  bool isVariadic = false;
2375  parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
2376  resultAttrs))
2377  return failure();
2378 
2379  SmallVector<Type> argTypes;
2380  for (auto &arg : entryArgs)
2381  argTypes.push_back(arg.type);
2382  auto fnType = builder.getFunctionType(argTypes, resultTypes);
2384  TypeAttr::get(fnType));
2385 
2386  // Parse the optional function control keyword.
2387  spirv::FunctionControl fnControl;
2388  if (parseEnumStrAttr<spirv::FunctionControlAttr>(fnControl, parser, result))
2389  return failure();
2390 
2391  // If additional attributes are present, parse them.
2392  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
2393  return failure();
2394 
2395  // Add the attributes to the function arguments.
2396  assert(resultAttrs.size() == resultTypes.size());
2397  function_interface_impl::addArgAndResultAttrs(builder, result, entryArgs,
2398  resultAttrs);
2399 
2400  // Parse the optional function body.
2401  auto *body = result.addRegion();
2402  OptionalParseResult parseResult =
2403  parser.parseOptionalRegion(*body, entryArgs);
2404  return failure(parseResult.has_value() && failed(*parseResult));
2405 }
2406 
2407 void spirv::FuncOp::print(OpAsmPrinter &printer) {
2408  // Print function name, signature, and control.
2409  printer << " ";
2410  printer.printSymbolName(getSymName());
2411  auto fnType = getFunctionType();
2413  printer, *this, fnType.getInputs(),
2414  /*isVariadic=*/false, fnType.getResults());
2415  printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl())
2416  << "\"";
2418  printer, *this, fnType.getNumInputs(), fnType.getNumResults(),
2419  {spirv::attributeName<spirv::FunctionControl>()});
2420 
2421  // Print the body if this is not an external function.
2422  Region &body = this->getBody();
2423  if (!body.empty()) {
2424  printer << ' ';
2425  printer.printRegion(body, /*printEntryBlockArgs=*/false,
2426  /*printBlockTerminators=*/true);
2427  }
2428 }
2429 
2430 LogicalResult spirv::FuncOp::verifyType() {
2431  auto type = getFunctionTypeAttr().getValue();
2432  if (!type.isa<FunctionType>())
2433  return emitOpError("requires '" + getTypeAttrName() +
2434  "' attribute of function type");
2435  if (getFunctionType().getNumResults() > 1)
2436  return emitOpError("cannot have more than one result");
2437  return success();
2438 }
2439 
2440 LogicalResult spirv::FuncOp::verifyBody() {
2441  FunctionType fnType = getFunctionType();
2442 
2443  auto walkResult = walk([fnType](Operation *op) -> WalkResult {
2444  if (auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
2445  if (fnType.getNumResults() != 0)
2446  return retOp.emitOpError("cannot be used in functions returning value");
2447  } else if (auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
2448  if (fnType.getNumResults() != 1)
2449  return retOp.emitOpError(
2450  "returns 1 value but enclosing function requires ")
2451  << fnType.getNumResults() << " results";
2452 
2453  auto retOperandType = retOp.getValue().getType();
2454  auto fnResultType = fnType.getResult(0);
2455  if (retOperandType != fnResultType)
2456  return retOp.emitOpError(" return value's type (")
2457  << retOperandType << ") mismatch with function's result type ("
2458  << fnResultType << ")";
2459  }
2460  return WalkResult::advance();
2461  });
2462 
2463  // TODO: verify other bits like linkage type.
2464 
2465  return failure(walkResult.wasInterrupted());
2466 }
2467 
2468 void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
2469  StringRef name, FunctionType type,
2470  spirv::FunctionControl control,
2471  ArrayRef<NamedAttribute> attrs) {
2473  builder.getStringAttr(name));
2474  state.addAttribute(getTypeAttrName(), TypeAttr::get(type));
2475  state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
2476  builder.getAttr<spirv::FunctionControlAttr>(control));
2477  state.attributes.append(attrs.begin(), attrs.end());
2478  state.addRegion();
2479 }
2480 
2481 // CallableOpInterface
2482 Region *spirv::FuncOp::getCallableRegion() {
2483  return isExternal() ? nullptr : &getBody();
2484 }
2485 
2486 // CallableOpInterface
2487 ArrayRef<Type> spirv::FuncOp::getCallableResults() {
2488  return getFunctionType().getResults();
2489 }
2490 
2491 //===----------------------------------------------------------------------===//
2492 // spirv.FunctionCall
2493 //===----------------------------------------------------------------------===//
2494 
2496  auto fnName = getCalleeAttr();
2497 
2498  auto funcOp = dyn_cast_or_null<spirv::FuncOp>(
2499  SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(), fnName));
2500  if (!funcOp) {
2501  return emitOpError("callee function '")
2502  << fnName.getValue() << "' not found in nearest symbol table";
2503  }
2504 
2505  auto functionType = funcOp.getFunctionType();
2506 
2507  if (getNumResults() > 1) {
2508  return emitOpError(
2509  "expected callee function to have 0 or 1 result, but provided ")
2510  << getNumResults();
2511  }
2512 
2513  if (functionType.getNumInputs() != getNumOperands()) {
2514  return emitOpError("has incorrect number of operands for callee: expected ")
2515  << functionType.getNumInputs() << ", but provided "
2516  << getNumOperands();
2517  }
2518 
2519  for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
2520  if (getOperand(i).getType() != functionType.getInput(i)) {
2521  return emitOpError("operand type mismatch: expected operand type ")
2522  << functionType.getInput(i) << ", but provided "
2523  << getOperand(i).getType() << " for operand number " << i;
2524  }
2525  }
2526 
2527  if (functionType.getNumResults() != getNumResults()) {
2528  return emitOpError(
2529  "has incorrect number of results has for callee: expected ")
2530  << functionType.getNumResults() << ", but provided "
2531  << getNumResults();
2532  }
2533 
2534  if (getNumResults() &&
2535  (getResult(0).getType() != functionType.getResult(0))) {
2536  return emitOpError("result type mismatch: expected ")
2537  << functionType.getResult(0) << ", but provided "
2538  << getResult(0).getType();
2539  }
2540 
2541  return success();
2542 }
2543 
2544 CallInterfaceCallable spirv::FunctionCallOp::getCallableForCallee() {
2545  return (*this)->getAttrOfType<SymbolRefAttr>(kCallee);
2546 }
2547 
2548 Operation::operand_range spirv::FunctionCallOp::getArgOperands() {
2549  return getArguments();
2550 }
2551 
2552 //===----------------------------------------------------------------------===//
2553 // spirv.GLFClampOp
2554 //===----------------------------------------------------------------------===//
2555 
2556 ParseResult spirv::GLFClampOp::parse(OpAsmParser &parser,
2557  OperationState &result) {
2558  return parseOneResultSameOperandTypeOp(parser, result);
2559 }
2561 
2562 //===----------------------------------------------------------------------===//
2563 // spirv.GLUClampOp
2564 //===----------------------------------------------------------------------===//
2565 
2566 ParseResult spirv::GLUClampOp::parse(OpAsmParser &parser,
2567  OperationState &result) {
2568  return parseOneResultSameOperandTypeOp(parser, result);
2569 }
2571 
2572 //===----------------------------------------------------------------------===//
2573 // spirv.GLSClampOp
2574 //===----------------------------------------------------------------------===//
2575 
2576 ParseResult spirv::GLSClampOp::parse(OpAsmParser &parser,
2577  OperationState &result) {
2578  return parseOneResultSameOperandTypeOp(parser, result);
2579 }
2581 
2582 //===----------------------------------------------------------------------===//
2583 // spirv.GLFmaOp
2584 //===----------------------------------------------------------------------===//
2585 
2586 ParseResult spirv::GLFmaOp::parse(OpAsmParser &parser, OperationState &result) {
2587  return parseOneResultSameOperandTypeOp(parser, result);
2588 }
2589 void spirv::GLFmaOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
2590 
2591 //===----------------------------------------------------------------------===//
2592 // spirv.GlobalVariable
2593 //===----------------------------------------------------------------------===//
2594 
2595 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
2596  Type type, StringRef name,
2597  unsigned descriptorSet, unsigned binding) {
2598  build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
2599  state.addAttribute(
2600  spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
2601  builder.getI32IntegerAttr(descriptorSet));
2602  state.addAttribute(
2603  spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
2604  builder.getI32IntegerAttr(binding));
2605 }
2606 
2607 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
2608  Type type, StringRef name,
2609  spirv::BuiltIn builtin) {
2610  build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
2611  state.addAttribute(
2612  spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
2613  builder.getStringAttr(spirv::stringifyBuiltIn(builtin)));
2614 }
2615 
2616 ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
2617  OperationState &result) {
2618  // Parse variable name.
2619  StringAttr nameAttr;
2620  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
2621  result.attributes)) {
2622  return failure();
2623  }
2624 
2625  // Parse optional initializer
2627  FlatSymbolRefAttr initSymbol;
2628  if (parser.parseLParen() ||
2629  parser.parseAttribute(initSymbol, Type(), kInitializerAttrName,
2630  result.attributes) ||
2631  parser.parseRParen())
2632  return failure();
2633  }
2634 
2635  if (parseVariableDecorations(parser, result)) {
2636  return failure();
2637  }
2638 
2639  Type type;
2640  auto loc = parser.getCurrentLocation();
2641  if (parser.parseColonType(type)) {
2642  return failure();
2643  }
2644  if (!type.isa<spirv::PointerType>()) {
2645  return parser.emitError(loc, "expected spirv.ptr type");
2646  }
2647  result.addAttribute(kTypeAttrName, TypeAttr::get(type));
2648 
2649  return success();
2650 }
2651 
2653  SmallVector<StringRef, 4> elidedAttrs{
2654  spirv::attributeName<spirv::StorageClass>()};
2655 
2656  // Print variable name.
2657  printer << ' ';
2658  printer.printSymbolName(getSymName());
2659  elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
2660 
2661  // Print optional initializer
2662  if (auto initializer = this->getInitializer()) {
2663  printer << " " << kInitializerAttrName << '(';
2664  printer.printSymbolName(*initializer);
2665  printer << ')';
2666  elidedAttrs.push_back(kInitializerAttrName);
2667  }
2668 
2669  elidedAttrs.push_back(kTypeAttrName);
2670  printVariableDecorations(*this, printer, elidedAttrs);
2671  printer << " : " << getType();
2672 }
2673 
// Verifies a spirv.GlobalVariable: the result type must be a spirv pointer,
// the storage class may be neither Generic nor Function, and an optional
// initializer symbol must resolve to a spirv.GlobalVariable or
// spirv.SpecConstant.
// NOTE(review): the diagnostic below spells the type '!spv.ptr', whereas the
// parser in this file reports "expected spirv.ptr type"; the dialect prefix is
// now 'spirv', so this message likely wants '!spirv.ptr' (check FileCheck
// tests before changing it).
2675  if (!getType().isa<spirv::PointerType>())
2676  return emitOpError("result must be of a !spv.ptr type");
2677 
2678  // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
2679  // object. It cannot be Generic. It must be the same as the Storage Class
2680  // operand of the Result Type."
2681  // Also, Function storage class is reserved by spirv.Variable.
2682  auto storageClass = this->storageClass();
2683  if (storageClass == spirv::StorageClass::Generic ||
2684  storageClass == spirv::StorageClass::Function) {
2685  return emitOpError("storage class cannot be '")
2686  << stringifyStorageClass(storageClass) << "'";
2687  }
2688 
// The initializer is stored as a flat symbol reference; it is resolved
// relative to the op's parent (symbol-table) op.
2689  if (auto init =
2690  (*this)->getAttrOfType<FlatSymbolRefAttr>(kInitializerAttrName)) {
2692  (*this)->getParentOp(), init.getAttr());
2693  // TODO: Currently only variable initialization with specialization
2694  // constants and other variables is supported. They could be normal
2695  // constants in the module scope as well.
2696  if (!initOp ||
2697  !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp>(initOp)) {
2698  return emitOpError("initializer must be result of a "
2699  "spirv.SpecConstant or spirv.GlobalVariable op");
2700  }
2701  }
2702 
2703  return success();
2704 }
2705 
2706 //===----------------------------------------------------------------------===//
2707 // spirv.GroupBroadcast
2708 //===----------------------------------------------------------------------===//
2709 
// Verifies spirv.GroupBroadcast: the execution scope must be Workgroup or
// Subgroup, and a vector-typed local id must have 2 or 3 components.
2711  spirv::Scope scope = getExecutionScope();
2712  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2713  return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
2714 
// A scalar local id is accepted as-is; only the vector form is size-checked.
// NOTE(review): the two string literals below concatenate to
// "... can be with only  2 or 3 components ..." (doubled space, awkward
// grammar); confirm against tests before rewording.
2715  if (auto localIdTy = getLocalid().getType().dyn_cast<VectorType>())
2716  if (localIdTy.getNumElements() != 2 && localIdTy.getNumElements() != 3)
2717  return emitOpError("localid is a vector and can be with only "
2718  " 2 or 3 components, actual number is ")
2719  << localIdTy.getNumElements();
2720 
2721  return success();
2722 }
2723 
2724 //===----------------------------------------------------------------------===//
2725 // spirv.GroupNonUniformBallotOp
2726 //===----------------------------------------------------------------------===//
2727 
2729  spirv::Scope scope = getExecutionScope();
2730  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2731  return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
2732 
2733  return success();
2734 }
2735 
2736 //===----------------------------------------------------------------------===//
2737 // spirv.GroupNonUniformBroadcast
2738 //===----------------------------------------------------------------------===//
2739 
2741  spirv::Scope scope = getExecutionScope();
2742  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2743  return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
2744 
2745  // SPIR-V spec: "Before version 1.5, Id must come from a
2746  // constant instruction.
2747  auto targetEnv = spirv::getDefaultTargetEnv(getContext());
2748  if (auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>())
2749  targetEnv = spirv::lookupTargetEnvOrDefault(spirvModule);
2750 
2751  if (targetEnv.getVersion() < spirv::Version::V_1_5) {
2752  auto *idOp = getId().getDefiningOp();
2753  if (!idOp || !isa<spirv::ConstantOp, // for normal constant
2754  spirv::ReferenceOfOp>(idOp)) // for spec constant
2755  return emitOpError("id must be the result of a constant op");
2756  }
2757 
2758  return success();
2759 }
2760 
2761 //===----------------------------------------------------------------------===//
2762 // spirv.GroupNonUniformShuffle*
2763 //===----------------------------------------------------------------------===//
2764 
// Shared verifier for the spirv.GroupNonUniformShuffle* ops: the execution
// scope must be Workgroup or Subgroup, and the last operand must not be a
// signed integer. The four per-op verify() bodies below just forward here.
2765 template <typename OpTy>
2767  spirv::Scope scope = op.getExecutionScope();
2768  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2769  return op.emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
2770 
// Only the last operand is inspected here.
// NOTE(review): "singless" in the diagnostic below is a typo for "signless";
// check FileCheck tests before fixing the message.
2771  if (op.getOperands().back().getType().isSignedInteger())
2772  return op.emitOpError("second operand must be a singless/unsigned integer");
2773 
2774  return success();
2775 }
2776 
2778  return verifyGroupNonUniformShuffleOp(*this);
2779 }
2781  return verifyGroupNonUniformShuffleOp(*this);
2782 }
2784  return verifyGroupNonUniformShuffleOp(*this);
2785 }
2787  return verifyGroupNonUniformShuffleOp(*this);
2788 }
2789 
2790 //===----------------------------------------------------------------------===//
2791 // spirv.INTEL.SubgroupBlockRead
2792 //===----------------------------------------------------------------------===//
2793 
2794 ParseResult spirv::INTELSubgroupBlockReadOp::parse(OpAsmParser &parser,
2795  OperationState &result) {
2796  // Parse the storage class specification
2797  spirv::StorageClass storageClass;
2799  Type elementType;
2800  if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
2801  parser.parseColon() || parser.parseType(elementType)) {
2802  return failure();
2803  }
2804 
2805  auto ptrType = spirv::PointerType::get(elementType, storageClass);
2806  if (auto valVecTy = elementType.dyn_cast<VectorType>())
2807  ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
2808 
2809  if (parser.resolveOperand(ptrInfo, ptrType, result.operands)) {
2810  return failure();
2811  }
2812 
2813  result.addTypes(elementType);
2814  return success();
2815 }
2816 
2818  printer << " " << getPtr() << " : " << getType();
2819 }
2820 
2822  if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
2823  return failure();
2824 
2825  return success();
2826 }
2827 
2828 //===----------------------------------------------------------------------===//
2829 // spirv.INTEL.SubgroupBlockWrite
2830 //===----------------------------------------------------------------------===//
2831 
2832 ParseResult spirv::INTELSubgroupBlockWriteOp::parse(OpAsmParser &parser,
2833  OperationState &result) {
2834  // Parse the storage class specification
2835  spirv::StorageClass storageClass;
2837  auto loc = parser.getCurrentLocation();
2838  Type elementType;
2839  if (parseEnumStrAttr(storageClass, parser) ||
2840  parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
2841  parser.parseType(elementType)) {
2842  return failure();
2843  }
2844 
2845  auto ptrType = spirv::PointerType::get(elementType, storageClass);
2846  if (auto valVecTy = elementType.dyn_cast<VectorType>())
2847  ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
2848 
2849  if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
2850  result.operands)) {
2851  return failure();
2852  }
2853  return success();
2854 }
2855 
2857  printer << " " << getPtr() << ", " << getValue() << " : "
2858  << getValue().getType();
2859 }
2860 
2862  if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
2863  return failure();
2864 
2865  return success();
2866 }
2867 
2868 //===----------------------------------------------------------------------===//
2869 // spirv.GroupNonUniformElectOp
2870 //===----------------------------------------------------------------------===//
2871 
2873  spirv::Scope scope = getExecutionScope();
2874  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2875  return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
2876 
2877  return success();
2878 }
2879 
2880 //===----------------------------------------------------------------------===//
2881 // spirv.GroupNonUniformFAddOp
2882 //===----------------------------------------------------------------------===//
2883 
2885  return verifyGroupNonUniformArithmeticOp(*this);
2886 }
2887 
2888 ParseResult spirv::GroupNonUniformFAddOp::parse(OpAsmParser &parser,
2889  OperationState &result) {
2890  return parseGroupNonUniformArithmeticOp(parser, result);
2891 }
2894 }
2895 
2896 //===----------------------------------------------------------------------===//
2897 // spirv.GroupNonUniformFMaxOp
2898 //===----------------------------------------------------------------------===//
2899 
2901  return verifyGroupNonUniformArithmeticOp(*this);
2902 }
2903 
2904 ParseResult spirv::GroupNonUniformFMaxOp::parse(OpAsmParser &parser,
2905  OperationState &result) {
2906  return parseGroupNonUniformArithmeticOp(parser, result);
2907 }
2910 }
2911 
2912 //===----------------------------------------------------------------------===//
2913 // spirv.GroupNonUniformFMinOp
2914 //===----------------------------------------------------------------------===//
2915 
2917  return verifyGroupNonUniformArithmeticOp(*this);
2918 }
2919 
2920 ParseResult spirv::GroupNonUniformFMinOp::parse(OpAsmParser &parser,
2921  OperationState &result) {
2922  return parseGroupNonUniformArithmeticOp(parser, result);
2923 }
2926 }
2927 
2928 //===----------------------------------------------------------------------===//
2929 // spirv.GroupNonUniformFMulOp
2930 //===----------------------------------------------------------------------===//
2931 
2933  return verifyGroupNonUniformArithmeticOp(*this);
2934 }
2935 
2936 ParseResult spirv::GroupNonUniformFMulOp::parse(OpAsmParser &parser,
2937  OperationState &result) {
2938  return parseGroupNonUniformArithmeticOp(parser, result);
2939 }
2942 }
2943 
2944 //===----------------------------------------------------------------------===//
2945 // spirv.GroupNonUniformIAddOp
2946 //===----------------------------------------------------------------------===//
2947 
2949  return verifyGroupNonUniformArithmeticOp(*this);
2950 }
2951 
2952 ParseResult spirv::GroupNonUniformIAddOp::parse(OpAsmParser &parser,
2953  OperationState &result) {
2954  return parseGroupNonUniformArithmeticOp(parser, result);
2955 }
2958 }
2959 
2960 //===----------------------------------------------------------------------===//
2961 // spirv.GroupNonUniformIMulOp
2962 //===----------------------------------------------------------------------===//
2963 
2965  return verifyGroupNonUniformArithmeticOp(*this);
2966 }
2967 
2968 ParseResult spirv::GroupNonUniformIMulOp::parse(OpAsmParser &parser,
2969  OperationState &result) {
2970  return parseGroupNonUniformArithmeticOp(parser, result);
2971 }
2974 }
2975 
2976 //===----------------------------------------------------------------------===//
2977 // spirv.GroupNonUniformSMaxOp
2978 //===----------------------------------------------------------------------===//
2979 
2981  return verifyGroupNonUniformArithmeticOp(*this);
2982 }
2983 
2984 ParseResult spirv::GroupNonUniformSMaxOp::parse(OpAsmParser &parser,
2985  OperationState &result) {
2986  return parseGroupNonUniformArithmeticOp(parser, result);
2987 }
2990 }
2991 
2992 //===----------------------------------------------------------------------===//
2993 // spirv.GroupNonUniformSMinOp
2994 //===----------------------------------------------------------------------===//
2995 
2997  return verifyGroupNonUniformArithmeticOp(*this);
2998 }
2999 
3000 ParseResult spirv::GroupNonUniformSMinOp::parse(OpAsmParser &parser,
3001  OperationState &result) {
3002  return parseGroupNonUniformArithmeticOp(parser, result);
3003 }
3006 }
3007 
3008 //===----------------------------------------------------------------------===//
3009 // spirv.GroupNonUniformUMaxOp
3010 //===----------------------------------------------------------------------===//
3011 
3013  return verifyGroupNonUniformArithmeticOp(*this);
3014 }
3015 
3016 ParseResult spirv::GroupNonUniformUMaxOp::parse(OpAsmParser &parser,
3017  OperationState &result) {
3018  return parseGroupNonUniformArithmeticOp(parser, result);
3019 }
3022 }
3023 
3024 //===----------------------------------------------------------------------===//
3025 // spirv.GroupNonUniformUMinOp
3026 //===----------------------------------------------------------------------===//
3027 
3029  return verifyGroupNonUniformArithmeticOp(*this);
3030 }
3031 
3032 ParseResult spirv::GroupNonUniformUMinOp::parse(OpAsmParser &parser,
3033  OperationState &result) {
3034  return parseGroupNonUniformArithmeticOp(parser, result);
3035 }
3038 }
3039 
3040 //===----------------------------------------------------------------------===//
3041 // spirv.IAddCarryOp
3042 //===----------------------------------------------------------------------===//
3043 
3046 }
3047 
3048 ParseResult spirv::IAddCarryOp::parse(OpAsmParser &parser,
3049  OperationState &result) {
3051 }
3052 
3053 void spirv::IAddCarryOp::print(OpAsmPrinter &printer) {
3054  ::printArithmeticExtendedBinaryOp(*this, printer);
3055 }
3056 
3057 //===----------------------------------------------------------------------===//
3058 // spirv.ISubBorrowOp
3059 //===----------------------------------------------------------------------===//
3060 
3063 }
3064 
3065 ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser,
3066  OperationState &result) {
3068 }
3069 
3071  ::printArithmeticExtendedBinaryOp(*this, printer);
3072 }
3073 
3074 //===----------------------------------------------------------------------===//
3075 // spirv.SMulExtended
3076 //===----------------------------------------------------------------------===//
3077 
3080 }
3081 
3082 ParseResult spirv::SMulExtendedOp::parse(OpAsmParser &parser,
3083  OperationState &result) {
3085 }
3086 
3088  ::printArithmeticExtendedBinaryOp(*this, printer);
3089 }
3090 
3091 //===----------------------------------------------------------------------===//
3092 // spirv.UMulExtended
3093 //===----------------------------------------------------------------------===//
3094 
3097 }
3098 
3099 ParseResult spirv::UMulExtendedOp::parse(OpAsmParser &parser,
3100  OperationState &result) {
3102 }
3103 
3105  ::printArithmeticExtendedBinaryOp(*this, printer);
3106 }
3107 
3108 //===----------------------------------------------------------------------===//
3109 // spirv.LoadOp
3110 //===----------------------------------------------------------------------===//
3111 
3112 void spirv::LoadOp::build(OpBuilder &builder, OperationState &state,
3113  Value basePtr, MemoryAccessAttr memoryAccess,
3114  IntegerAttr alignment) {
3115  auto ptrType = basePtr.getType().cast<spirv::PointerType>();
3116  build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess,
3117  alignment);
3118 }
3119 
3120 ParseResult spirv::LoadOp::parse(OpAsmParser &parser, OperationState &result) {
3121  // Parse the storage class specification
3122  spirv::StorageClass storageClass;
3124  Type elementType;
3125  if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
3126  parseMemoryAccessAttributes(parser, result) ||
3127  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
3128  parser.parseType(elementType)) {
3129  return failure();
3130  }
3131 
3132  auto ptrType = spirv::PointerType::get(elementType, storageClass);
3133  if (parser.resolveOperand(ptrInfo, ptrType, result.operands)) {
3134  return failure();
3135  }
3136 
3137  result.addTypes(elementType);
3138  return success();
3139 }
3140 
3141 void spirv::LoadOp::print(OpAsmPrinter &printer) {
3142  SmallVector<StringRef, 4> elidedAttrs;
3143  StringRef sc = stringifyStorageClass(
3144  getPtr().getType().cast<spirv::PointerType>().getStorageClass());
3145  printer << " \"" << sc << "\" " << getPtr();
3146 
3147  printMemoryAccessAttribute(*this, printer, elidedAttrs);
3148 
3149  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
3150  printer << " : " << getType();
3151 }
3152 
3154  // SPIR-V spec : "Result Type is the type of the loaded object. It must be a
3155  // type with fixed size; i.e., it cannot be, nor include, any
3156  // OpTypeRuntimeArray types."
3157  if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue()))) {
3158  return failure();
3159  }
3160  return verifyMemoryAccessAttribute(*this);
3161 }
3162 
3163 //===----------------------------------------------------------------------===//
3164 // spirv.mlir.loop
3165 //===----------------------------------------------------------------------===//
3166 
3167 void spirv::LoopOp::build(OpBuilder &builder, OperationState &state) {
3168  state.addAttribute("loop_control", builder.getAttr<spirv::LoopControlAttr>(
3170  state.addRegion();
3171 }
3172 
3173 ParseResult spirv::LoopOp::parse(OpAsmParser &parser, OperationState &result) {
3174  if (parseControlAttribute<spirv::LoopControlAttr, spirv::LoopControl>(parser,
3175  result))
3176  return failure();
3177  return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
3178 }
3179 
3180 void spirv::LoopOp::print(OpAsmPrinter &printer) {
3181  auto control = getLoopControl();
3182  if (control != spirv::LoopControl::None)
3183  printer << " control(" << spirv::stringifyLoopControl(control) << ")";
3184  printer << ' ';
3185  printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
3186  /*printBlockTerminators=*/true);
3187 }
3188 
3189 /// Returns true if the given `srcBlock` contains only one `spirv.Branch` to the
3190 /// given `dstBlock`.
3191 static inline bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) {
3192  // Check that there is only one op in the `srcBlock`.
3193  if (!llvm::hasSingleElement(srcBlock))
3194  return false;
3195 
3196  auto branchOp = dyn_cast<spirv::BranchOp>(srcBlock.back());
3197  return branchOp && branchOp.getSuccessor() == &dstBlock;
3198 }
3199 
3200 LogicalResult spirv::LoopOp::verifyRegions() {
3201  auto *op = getOperation();
3202 
3203  // We need to verify that the blocks follow the following layout:
3204  //
3205  // +-------------+
3206  // | entry block |
3207  // +-------------+
3208  // |
3209  // v
3210  // +-------------+
3211  // | loop header | <-----+
3212  // +-------------+ |
3213  // |
3214  // ... |
3215  // \ | / |
3216  // v |
3217  // +---------------+ |
3218  // | loop continue | -----+
3219  // +---------------+
3220  //
3221  // ...
3222  // \ | /
3223  // v
3224  // +-------------+
3225  // | merge block |
3226  // +-------------+
3227 
3228  auto &region = op->getRegion(0);
3229  // Allow empty region as a degenerated case, which can come from
3230  // optimizations.
3231  if (region.empty())
3232  return success();
3233 
3234  // The last block is the merge block.
3235  Block &merge = region.back();
3236  if (!isMergeBlock(merge))
3237  return emitOpError("last block must be the merge block with only one "
3238  "'spirv.mlir.merge' op");
3239 
3240  if (std::next(region.begin()) == region.end())
3241  return emitOpError(
3242  "must have an entry block branching to the loop header block");
3243  // The first block is the entry block.
3244  Block &entry = region.front();
3245 
3246  if (std::next(region.begin(), 2) == region.end())
3247  return emitOpError(
3248  "must have a loop header block branched from the entry block");
3249  // The second block is the loop header block.
3250  Block &header = *std::next(region.begin(), 1);
3251 
3252  if (!hasOneBranchOpTo(entry, header))
3253  return emitOpError(
3254  "entry block must only have one 'spirv.Branch' op to the second block");
3255 
3256  if (std::next(region.begin(), 3) == region.end())
3257  return emitOpError(
3258  "requires a loop continue block branching to the loop header block");
3259  // The second to last block is the loop continue block.
3260  Block &cont = *std::prev(region.end(), 2);
3261 
3262  // Make sure that we have a branch from the loop continue block to the loop
3263  // header block.
3264  if (llvm::none_of(
3265  llvm::seq<unsigned>(0, cont.getNumSuccessors()),
3266  [&](unsigned index) { return cont.getSuccessor(index) == &header; }))
3267  return emitOpError("second to last block must be the loop continue "
3268  "block that branches to the loop header block");
3269 
3270  // Make sure that no other blocks (except the entry and loop continue block)
3271  // branches to the loop header block.
3272  for (auto &block : llvm::make_range(std::next(region.begin(), 2),
3273  std::prev(region.end(), 2))) {
3274  for (auto i : llvm::seq<unsigned>(0, block.getNumSuccessors())) {
3275  if (block.getSuccessor(i) == &header) {
3276  return emitOpError("can only have the entry and loop continue "
3277  "block branching to the loop header block");
3278  }
3279  }
3280  }
3281 
3282  return success();
3283 }
3284 
3285 Block *spirv::LoopOp::getEntryBlock() {
3286  assert(!getBody().empty() && "op region should not be empty!");
3287  return &getBody().front();
3288 }
3289 
3290 Block *spirv::LoopOp::getHeaderBlock() {
3291  assert(!getBody().empty() && "op region should not be empty!");
3292  // The second block is the loop header block.
3293  return &*std::next(getBody().begin());
3294 }
3295 
3296 Block *spirv::LoopOp::getContinueBlock() {
3297  assert(!getBody().empty() && "op region should not be empty!");
3298  // The second to last block is the loop continue block.
3299  return &*std::prev(getBody().end(), 2);
3300 }
3301 
3302 Block *spirv::LoopOp::getMergeBlock() {
3303  assert(!getBody().empty() && "op region should not be empty!");
3304  // The last block is the loop merge block.
3305  return &getBody().back();
3306 }
3307 
3308 void spirv::LoopOp::addEntryAndMergeBlock() {
3309  assert(getBody().empty() && "entry and merge block already exist");
3310  getBody().push_back(new Block());
3311  auto *mergeBlock = new Block();
3312  getBody().push_back(mergeBlock);
3313  OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
3314 
3315  // Add a spirv.mlir.merge op into the merge block.
3316  builder.create<spirv::MergeOp>(getLoc());
3317 }
3318 
3319 //===----------------------------------------------------------------------===//
3320 // spirv.MemoryBarrierOp
3321 //===----------------------------------------------------------------------===//
3322 
3324  return verifyMemorySemantics(getOperation(), getMemorySemantics());
3325 }
3326 
3327 //===----------------------------------------------------------------------===//
3328 // spirv.mlir.merge
3329 //===----------------------------------------------------------------------===//
3330 
3332  auto *parentOp = (*this)->getParentOp();
3333  if (!parentOp || !isa<spirv::SelectionOp, spirv::LoopOp>(parentOp))
3334  return emitOpError(
3335  "expected parent op to be 'spirv.mlir.selection' or 'spirv.mlir.loop'");
3336 
3337  // TODO: This check should be done in `verifyRegions` of parent op.
3338  Block &parentLastBlock = (*this)->getParentRegion()->back();
3339  if (getOperation() != parentLastBlock.getTerminator())
3340  return emitOpError("can only be used in the last block of "
3341  "'spirv.mlir.selection' or 'spirv.mlir.loop'");
3342  return success();
3343 }
3344 
3345 //===----------------------------------------------------------------------===//
3346 // spirv.module
3347 //===----------------------------------------------------------------------===//
3348 
3349 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
3350  Optional<StringRef> name) {
3351  OpBuilder::InsertionGuard guard(builder);
3352  builder.createBlock(state.addRegion());
3353  if (name) {
3355  builder.getStringAttr(*name));
3356  }
3357 }
3358 
3359 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
3360  spirv::AddressingModel addressingModel,
3361  spirv::MemoryModel memoryModel,
3362  Optional<VerCapExtAttr> vceTriple,
3363  Optional<StringRef> name) {
3364  state.addAttribute(
3365  "addressing_model",
3366  builder.getAttr<spirv::AddressingModelAttr>(addressingModel));
3367  state.addAttribute("memory_model",
3368  builder.getAttr<spirv::MemoryModelAttr>(memoryModel));
3369  OpBuilder::InsertionGuard guard(builder);
3370  builder.createBlock(state.addRegion());
3371  if (vceTriple)
3372  state.addAttribute(getVCETripleAttrName(), *vceTriple);
3373  if (name)
3375  builder.getStringAttr(*name));
3376 }
3377 
3378 ParseResult spirv::ModuleOp::parse(OpAsmParser &parser,
3379  OperationState &result) {
3380  Region *body = result.addRegion();
3381 
3382  // If the name is present, parse it.
3383  StringAttr nameAttr;
3384  (void)parser.parseOptionalSymbolName(
3385  nameAttr, mlir::SymbolTable::getSymbolAttrName(), result.attributes);
3386 
3387  // Parse attributes
3388  spirv::AddressingModel addrModel;
3389  spirv::MemoryModel memoryModel;
3390  if (::parseEnumKeywordAttr<spirv::AddressingModelAttr>(addrModel, parser,
3391  result) ||
3392  ::parseEnumKeywordAttr<spirv::MemoryModelAttr>(memoryModel, parser,
3393  result))
3394  return failure();
3395 
3396  if (succeeded(parser.parseOptionalKeyword("requires"))) {
3397  spirv::VerCapExtAttr vceTriple;
3398  if (parser.parseAttribute(vceTriple,
3399  spirv::ModuleOp::getVCETripleAttrName(),
3400  result.attributes))
3401  return failure();
3402  }
3403 
3404  if (parser.parseOptionalAttrDictWithKeyword(result.attributes) ||
3405  parser.parseRegion(*body, /*arguments=*/{}))
3406  return failure();
3407 
3408  // Make sure we have at least one block.
3409  if (body->empty())
3410  body->push_back(new Block());
3411 
3412  return success();
3413 }
3414 
3415 void spirv::ModuleOp::print(OpAsmPrinter &printer) {
3416  if (Optional<StringRef> name = getName()) {
3417  printer << ' ';
3418  printer.printSymbolName(*name);
3419  }
3420 
3421  SmallVector<StringRef, 2> elidedAttrs;
3422 
3423  printer << " " << spirv::stringifyAddressingModel(getAddressingModel()) << " "
3424  << spirv::stringifyMemoryModel(getMemoryModel());
3425  auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
3426  auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
3427  elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
3429 
3430  if (Optional<spirv::VerCapExtAttr> triple = getVceTriple()) {
3431  printer << " requires " << *triple;
3432  elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
3433  }
3434 
3435  printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs);
3436  printer << ' ';
3437  printer.printRegion(getRegion());
3438 }
3439 
// Verifies the body of spirv.module:
//  - every top-level op must belong to the SPIR-V dialect;
//  - each spirv.EntryPoint must name a spirv.func found in the module's symbol
//    table, every interface entry must be a flat symbol reference to a
//    spirv.GlobalVariable, and no (function, execution model) pair may appear
//    in two entry points;
//  - spirv.func ops must not be external and may only contain SPIR-V ops.
3440 LogicalResult spirv::ModuleOp::verifyRegions() {
3441  Dialect *dialect = (*this)->getDialect();
3443  entryPoints;
3444  mlir::SymbolTable table(*this);
3445 
// NOTE(review): `entryPoints` is presumably a map keyed by
// (spirv::FuncOp, spirv::ExecutionModel) -> spirv::EntryPointOp, judging by
// the find()/operator[] uses below; its declaration line is not visible here.
3446  for (auto &op : *getBody()) {
3447  if (op.getDialect() != dialect)
3448  return op.emitError("'spirv.module' can only contain spirv.* ops");
3449 
3450  // For EntryPoint op, check that the function and execution model is not
3451  // duplicated in EntryPointOps. Also verify that the interface specified
3452  // comes from globalVariables here to make this check cheaper.
3453  if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
3454  auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.getFn());
3455  if (!funcOp) {
3456  return entryPointOp.emitError("function '")
3457  << entryPointOp.getFn() << "' not found in 'spirv.module'";
3458  }
3459  if (auto interface = entryPointOp.getInterface()) {
3460  for (Attribute varRef : interface) {
3461  auto varSymRef = varRef.dyn_cast<FlatSymbolRefAttr>();
// NOTE(review): this diagnostic opens a quote ("instead of '") but never
// closes it after streaming `varRef`.
3462  if (!varSymRef) {
3463  return entryPointOp.emitError(
3464  "expected symbol reference for interface "
3465  "specification instead of '")
3466  << varRef;
3467  }
3468  auto variableOp =
3469  table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
// NOTE(review): "instead of'" below is missing a space before the quote.
3470  if (!variableOp) {
3471  return entryPointOp.emitError("expected spirv.GlobalVariable "
3472  "symbol reference instead of'")
3473  << varSymRef << "'";
3474  }
3475  }
3476  }
3477 
// Reject a second entry point for the same (function, execution model).
3478  auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
3479  funcOp, entryPointOp.getExecutionModel());
3480  auto entryPtIt = entryPoints.find(key);
3481  if (entryPtIt != entryPoints.end()) {
3482  return entryPointOp.emitError("duplicate of a previous EntryPointOp");
3483  }
3484  entryPoints[key] = entryPointOp;
3485  } else if (auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
3486  if (funcOp.isExternal())
3487  return op.emitError("'spirv.module' cannot contain external functions");
3488 
// NOTE(review): the inner loop variable `op` shadows the outer `op`; the
// diagnostic is therefore attached to the offending nested op.
3489  // TODO: move this check to spirv.func.
3490  for (auto &block : funcOp)
3491  for (auto &op : block) {
3492  if (op.getDialect() != dialect)
3493  return op.emitError(
3494  "functions in 'spirv.module' can only contain spirv.* ops");
3495  }
3496  }
3497  }
3498 
3499  return success();
3500 }
3501 
3502 //===----------------------------------------------------------------------===//
3503 // spirv.mlir.referenceof
3504 //===----------------------------------------------------------------------===//
3505 
3507  auto *specConstSym = SymbolTable::lookupNearestSymbolFrom(
3508  (*this)->getParentOp(), getSpecConstAttr());
3509  Type constType;
3510 
3511  auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
3512  if (specConstOp)
3513  constType = specConstOp.getDefaultValue().getType();
3514 
3515  auto specConstCompositeOp =
3516  dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
3517  if (specConstCompositeOp)
3518  constType = specConstCompositeOp.getType();
3519 
3520  if (!specConstOp && !specConstCompositeOp)
3521  return emitOpError(
3522  "expected spirv.SpecConstant or spirv.SpecConstantComposite symbol");
3523 
3524  if (getReference().getType() != constType)
3525  return emitOpError("result type mismatch with the referenced "
3526  "specialization constant's type");
3527 
3528  return success();
3529 }
3530 
3531 //===----------------------------------------------------------------------===//
3532 // spirv.Return
3533 //===----------------------------------------------------------------------===//
3534 
3536  // Verification is performed in spirv.func op.
3537  return success();
3538 }
3539 
3540 //===----------------------------------------------------------------------===//
3541 // spirv.ReturnValue
3542 //===----------------------------------------------------------------------===//
3543 
3545  // Verification is performed in spirv.func op.
3546  return success();
3547 }
3548 
3549 //===----------------------------------------------------------------------===//
3550 // spirv.Select
3551 //===----------------------------------------------------------------------===//
3552 
3554  if (auto conditionTy = getCondition().getType().dyn_cast<VectorType>()) {
3555  auto resultVectorTy = getResult().getType().dyn_cast<VectorType>();
3556  if (!resultVectorTy) {
3557  return emitOpError("result expected to be of vector type when "
3558  "condition is of vector type");
3559  }
3560  if (resultVectorTy.getNumElements() != conditionTy.getNumElements()) {
3561  return emitOpError("result should have the same number of elements as "
3562  "the condition when condition is of vector type");
3563  }
3564  }
3565  return success();
3566 }
3567 
3568 //===----------------------------------------------------------------------===//
3569 // spirv.mlir.selection
3570 //===----------------------------------------------------------------------===//
3571 
3572 ParseResult spirv::SelectionOp::parse(OpAsmParser &parser,
3573  OperationState &result) {
3574  if (parseControlAttribute<spirv::SelectionControlAttr,
3575  spirv::SelectionControl>(parser, result))
3576  return failure();
3577  return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
3578 }
3579 
3580 void spirv::SelectionOp::print(OpAsmPrinter &printer) {
3581  auto control = getSelectionControl();
3582  if (control != spirv::SelectionControl::None)
3583  printer << " control(" << spirv::stringifySelectionControl(control) << ")";
3584  printer << ' ';
3585  printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
3586  /*printBlockTerminators=*/true);
3587 }
3588 
3589 LogicalResult spirv::SelectionOp::verifyRegions() {
3590  auto *op = getOperation();
3591 
3592  // We need to verify that the blocks follow the following layout:
3593  //
3594  // +--------------+
3595  // | header block |
3596  // +--------------+
3597  // / | \
3598  // ...
3599  //
3600  //
3601  // +---------+ +---------+ +---------+
3602  // | case #0 | | case #1 | | case #2 | ...
3603  // +---------+ +---------+ +---------+
3604  //
3605  //
3606  // ...
3607  // \ | /
3608  // v
3609  // +-------------+
3610  // | merge block |
3611  // +-------------+
3612 
3613  auto &region = op->getRegion(0);
3614  // Allow empty region as a degenerated case, which can come from
3615  // optimizations.
3616  if (region.empty())
3617  return success();
3618 
3619  // The last block is the merge block.
3620  if (!isMergeBlock(region.back()))
3621  return emitOpError("last block must be the merge block with only one "
3622  "'spirv.mlir.merge' op");
3623 
3624  if (std::next(region.begin()) == region.end())
3625  return emitOpError("must have a selection header block");
3626 
3627  return success();
3628 }
3629 
3630 Block *spirv::SelectionOp::getHeaderBlock() {
3631  assert(!getBody().empty() && "op region should not be empty!");
3632  // The first block is the loop header block.
3633  return &getBody().front();
3634 }
3635 
3636 Block *spirv::SelectionOp::getMergeBlock() {
3637  assert(!getBody().empty() && "op region should not be empty!");
3638  // The last block is the loop merge block.
3639  return &getBody().back();
3640 }
3641 
3642 void spirv::SelectionOp::addMergeBlock() {
3643  assert(getBody().empty() && "entry and merge block already exist");
3644  auto *mergeBlock = new Block();
3645  getBody().push_back(mergeBlock);
3646  OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
3647 
3648  // Add a spirv.mlir.merge op into the merge block.
3649  builder.create<spirv::MergeOp>(getLoc());
3650 }
3651 
3652 spirv::SelectionOp spirv::SelectionOp::createIfThen(
3653  Location loc, Value condition,
3654  function_ref<void(OpBuilder &builder)> thenBody, OpBuilder &builder) {
3655  auto selectionOp =
3656  builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
3657 
3658  selectionOp.addMergeBlock();
3659  Block *mergeBlock = selectionOp.getMergeBlock();
3660  Block *thenBlock = nullptr;
3661 
3662  // Build the "then" block.
3663  {
3664  OpBuilder::InsertionGuard guard(builder);
3665  thenBlock = builder.createBlock(mergeBlock);
3666  thenBody(builder);
3667  builder.create<spirv::BranchOp>(loc, mergeBlock);
3668  }
3669 
3670  // Build the header block.
3671  {
3672  OpBuilder::InsertionGuard guard(builder);
3673  builder.createBlock(thenBlock);
3674  builder.create<spirv::BranchConditionalOp>(
3675  loc, condition, thenBlock,
3676  /*trueArguments=*/ArrayRef<Value>(), mergeBlock,
3677  /*falseArguments=*/ArrayRef<Value>());
3678  }
3679 
3680  return selectionOp;
3681 }
3682 
3683 //===----------------------------------------------------------------------===//
3684 // spirv.SpecConstant
3685 //===----------------------------------------------------------------------===//
3686 
3687 ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
3688  OperationState &result) {
3689  StringAttr nameAttr;
3690  Attribute valueAttr;
3691 
3692  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
3693  result.attributes))
3694  return failure();
3695 
3696  // Parse optional spec_id.
3698  IntegerAttr specIdAttr;
3699  if (parser.parseLParen() ||
3700  parser.parseAttribute(specIdAttr, kSpecIdAttrName, result.attributes) ||
3701  parser.parseRParen())
3702  return failure();
3703  }
3704 
3705  if (parser.parseEqual() ||
3706  parser.parseAttribute(valueAttr, kDefaultValueAttrName,
3707  result.attributes))
3708  return failure();
3709 
3710  return success();
3711 }
3712 
3714  printer << ' ';
3715  printer.printSymbolName(getSymName());
3716  if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
3717  printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')';
3718  printer << " = " << getDefaultValue();
3719 }
3720 
3722  if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
3723  if (specID.getValue().isNegative())
3724  return emitOpError("SpecId cannot be negative");
3725 
3726  auto value = getDefaultValue();
3727  if (value.isa<IntegerAttr, FloatAttr>()) {
3728  // Make sure bitwidth is allowed.
3729  if (!value.getType().isa<spirv::SPIRVType>())
3730  return emitOpError("default value bitwidth disallowed");
3731  return success();
3732  }
3733  return emitOpError(
3734  "default value can only be a bool, integer, or float scalar");
3735 }
3736 
3737 //===----------------------------------------------------------------------===//
3738 // spirv.StoreOp
3739 //===----------------------------------------------------------------------===//
3740 
3741 ParseResult spirv::StoreOp::parse(OpAsmParser &parser, OperationState &result) {
3742  // Parse the storage class specification
3743  spirv::StorageClass storageClass;
3745  auto loc = parser.getCurrentLocation();
3746  Type elementType;
3747  if (parseEnumStrAttr(storageClass, parser) ||
3748  parser.parseOperandList(operandInfo, 2) ||
3749  parseMemoryAccessAttributes(parser, result) || parser.parseColon() ||
3750  parser.parseType(elementType)) {
3751  return failure();
3752  }
3753 
3754  auto ptrType = spirv::PointerType::get(elementType, storageClass);
3755  if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
3756  result.operands)) {
3757  return failure();
3758  }
3759  return success();
3760 }
3761 
3762 void spirv::StoreOp::print(OpAsmPrinter &printer) {
3763  SmallVector<StringRef, 4> elidedAttrs;
3764  StringRef sc = stringifyStorageClass(
3765  getPtr().getType().cast<spirv::PointerType>().getStorageClass());
3766  printer << " \"" << sc << "\" " << getPtr() << ", " << getValue();
3767 
3768  printMemoryAccessAttribute(*this, printer, elidedAttrs);
3769 
3770  printer << " : " << getValue().getType();
3771  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
3772 }
3773 
3775  // SPIR-V spec : "Pointer is the pointer to store through. Its type must be an
3776  // OpTypePointer whose Type operand is the same as the type of Object."
3777  if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue())))
3778  return failure();
3779  return verifyMemoryAccessAttribute(*this);
3780 }
3781 
3782 //===----------------------------------------------------------------------===//
3783 // spirv.Unreachable
3784 //===----------------------------------------------------------------------===//
3785 
3787  auto *block = (*this)->getBlock();
3788  // Fast track: if this is in entry block, its invalid. Otherwise, if no
3789  // predecessors, it's valid.
3790  if (block->isEntryBlock())
3791  return emitOpError("cannot be used in reachable block");
3792  if (block->hasNoPredecessors())
3793  return success();
3794 
3795  // TODO: further verification needs to analyze reachability from
3796  // the entry block.
3797 
3798  return success();
3799 }
3800 
3801 //===----------------------------------------------------------------------===//
3802 // spirv.Variable
3803 //===----------------------------------------------------------------------===//
3804 
3805 ParseResult spirv::VariableOp::parse(OpAsmParser &parser,
3806  OperationState &result) {
3807  // Parse optional initializer
3809  if (succeeded(parser.parseOptionalKeyword("init"))) {
3810  initInfo = OpAsmParser::UnresolvedOperand();
3811  if (parser.parseLParen() || parser.parseOperand(*initInfo) ||
3812  parser.parseRParen())
3813  return failure();
3814  }
3815 
3816  if (parseVariableDecorations(parser, result)) {
3817  return failure();
3818  }
3819 
3820  // Parse result pointer type
3821  Type type;
3822  if (parser.parseColon())
3823  return failure();
3824  auto loc = parser.getCurrentLocation();
3825  if (parser.parseType(type))
3826  return failure();
3827 
3828  auto ptrType = type.dyn_cast<spirv::PointerType>();
3829  if (!ptrType)
3830  return parser.emitError(loc, "expected spirv.ptr type");
3831  result.addTypes(ptrType);
3832 
3833  // Resolve the initializer operand
3834  if (initInfo) {
3835  if (parser.resolveOperand(*initInfo, ptrType.getPointeeType(),
3836  result.operands))
3837  return failure();
3838  }
3839 
3840  auto attr = parser.getBuilder().getAttr<spirv::StorageClassAttr>(
3841  ptrType.getStorageClass());
3842  result.addAttribute(spirv::attributeName<spirv::StorageClass>(), attr);
3843 
3844  return success();
3845 }
3846 
3847 void spirv::VariableOp::print(OpAsmPrinter &printer) {
3848  SmallVector<StringRef, 4> elidedAttrs{
3849  spirv::attributeName<spirv::StorageClass>()};
3850  // Print optional initializer
3851  if (getNumOperands() != 0)
3852  printer << " init(" << getInitializer() << ")";
3853 
3854  printVariableDecorations(*this, printer, elidedAttrs);
3855  printer << " : " << getType();
3856 }
3857 
3859  // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
3860  // object. It cannot be Generic. It must be the same as the Storage Class
3861  // operand of the Result Type."
3862  if (getStorageClass() != spirv::StorageClass::Function) {
3863  return emitOpError(
3864  "can only be used to model function-level variables. Use "
3865  "spirv.GlobalVariable for module-level variables.");
3866  }
3867 
3868  auto pointerType = getPointer().getType().cast<spirv::PointerType>();
3869  if (getStorageClass() != pointerType.getStorageClass())
3870  return emitOpError(
3871  "storage class must match result pointer's storage class");
3872 
3873  if (getNumOperands() != 0) {
3874  // SPIR-V spec: "Initializer must be an <id> from a constant instruction or
3875  // a global (module scope) OpVariable instruction".
3876  auto *initOp = getOperand(0).getDefiningOp();
3877  if (!initOp || !isa<spirv::ConstantOp, // for normal constant
3878  spirv::ReferenceOfOp, // for spec constant
3879  spirv::AddressOfOp>(initOp))
3880  return emitOpError("initializer must be the result of a "
3881  "constant or spirv.GlobalVariable op");
3882  }
3883 
3884  // TODO: generate these strings using ODS.
3885  auto *op = getOperation();
3886  auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
3887  stringifyDecoration(spirv::Decoration::DescriptorSet));
3888  auto bindingName = llvm::convertToSnakeFromCamelCase(
3889  stringifyDecoration(spirv::Decoration::Binding));
3890  auto builtInName = llvm::convertToSnakeFromCamelCase(
3891  stringifyDecoration(spirv::Decoration::BuiltIn));
3892 
3893  for (const auto &attr : {descriptorSetName, bindingName, builtInName}) {
3894  if (op->getAttr(attr))
3895  return emitOpError("cannot have '")
3896  << attr << "' attribute (only allowed in spirv.GlobalVariable)";
3897  }
3898 
3899  return success();
3900 }
3901 
3902 //===----------------------------------------------------------------------===//
3903 // spirv.VectorShuffle
3904 //===----------------------------------------------------------------------===//
3905 
3907  VectorType resultType = getType().cast<VectorType>();
3908 
3909  size_t numResultElements = resultType.getNumElements();
3910  if (numResultElements != getComponents().size())
3911  return emitOpError("result type element count (")
3912  << numResultElements
3913  << ") mismatch with the number of component selectors ("
3914  << getComponents().size() << ")";
3915 
3916  size_t totalSrcElements =
3917  getVector1().getType().cast<VectorType>().getNumElements() +
3918  getVector2().getType().cast<VectorType>().getNumElements();
3919 
3920  for (const auto &selector : getComponents().getAsValueRange<IntegerAttr>()) {
3921  uint32_t index = selector.getZExtValue();
3922  if (index >= totalSrcElements &&
3923  index != std::numeric_limits<uint32_t>().max())
3924  return emitOpError("component selector ")
3925  << index << " out of range: expected to be in [0, "
3926  << totalSrcElements << ") or 0xffffffff";
3927  }
3928  return success();
3929 }
3930 
3931 //===----------------------------------------------------------------------===//
3932 // spirv.NV.CooperativeMatrixLoad
3933 //===----------------------------------------------------------------------===//
3934 
3935 ParseResult spirv::NVCooperativeMatrixLoadOp::parse(OpAsmParser &parser,
3936  OperationState &result) {
3938  Type strideType = parser.getBuilder().getIntegerType(32);
3939  Type columnMajorType = parser.getBuilder().getIntegerType(1);
3940  Type ptrType;
3941  Type elementType;
3942  if (parser.parseOperandList(operandInfo, 3) ||
3943  parseMemoryAccessAttributes(parser, result) || parser.parseColon() ||
3944  parser.parseType(ptrType) || parser.parseKeywordType("as", elementType)) {
3945  return failure();
3946  }
3947  if (parser.resolveOperands(operandInfo,
3948  {ptrType, strideType, columnMajorType},
3949  parser.getNameLoc(), result.operands)) {
3950  return failure();
3951  }
3952 
3953  result.addTypes(elementType);
3954  return success();
3955 }
3956 
3958  printer << " " << getPointer() << ", " << getStride() << ", "
3959  << getColumnmajor();
3960  // Print optional memory access attribute.
3961  if (auto memAccess = getMemoryAccess())
3962  printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
3963  printer << " : " << getPointer().getType() << " as " << getType();
3964 }
3965 
3967  Type coopMatrix) {
3968  Type pointeeType = pointer.cast<spirv::PointerType>().getPointeeType();
3969  if (!pointeeType.isa<spirv::ScalarType>() && !pointeeType.isa<VectorType>())
3970  return op->emitError(
3971  "Pointer must point to a scalar or vector type but provided ")
3972  << pointeeType;
3973  spirv::StorageClass storage =
3974  pointer.cast<spirv::PointerType>().getStorageClass();
3975  if (storage != spirv::StorageClass::Workgroup &&
3976  storage != spirv::StorageClass::StorageBuffer &&
3977  storage != spirv::StorageClass::PhysicalStorageBuffer)
3978  return op->emitError(
3979  "Pointer storage class must be Workgroup, StorageBuffer or "
3980  "PhysicalStorageBufferEXT but provided ")
3981  << stringifyStorageClass(storage);
3982  return success();
3983 }
3984 
3986  return verifyPointerAndCoopMatrixType(*this, getPointer().getType(),
3987  getResult().getType());
3988 }
3989 
3990 //===----------------------------------------------------------------------===//
3991 // spirv.NV.CooperativeMatrixStore
3992 //===----------------------------------------------------------------------===//
3993 
3994 ParseResult spirv::NVCooperativeMatrixStoreOp::parse(OpAsmParser &parser,
3995  OperationState &result) {
3997  Type strideType = parser.getBuilder().getIntegerType(32);
3998  Type columnMajorType = parser.getBuilder().getIntegerType(1);
3999  Type ptrType;
4000  Type elementType;
4001  if (parser.parseOperandList(operandInfo, 4) ||
4002  parseMemoryAccessAttributes(parser, result) || parser.parseColon() ||
4003  parser.parseType(ptrType) || parser.parseComma() ||
4004  parser.parseType(elementType)) {
4005  return failure();
4006  }
4007  if (parser.resolveOperands(
4008  operandInfo, {ptrType, elementType, strideType, columnMajorType},
4009  parser.getNameLoc(), result.operands)) {
4010  return failure();
4011  }
4012 
4013  return success();
4014 }
4015 
4017  printer << " " << getPointer() << ", " << getObject() << ", " << getStride()
4018  << ", " << getColumnmajor();
4019  // Print optional memory access attribute.
4020  if (auto memAccess = getMemoryAccess())
4021  printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
4022  printer << " : " << getPointer().getType() << ", " << getOperand(1).getType();
4023 }
4024 
4026  return verifyPointerAndCoopMatrixType(*this, getPointer().getType(),
4027  getObject().getType());
4028 }
4029 
4030 //===----------------------------------------------------------------------===//
4031 // spirv.NV.CooperativeMatrixMulAdd
4032 //===----------------------------------------------------------------------===//
4033 
4034 static LogicalResult
4035 verifyCoopMatrixMulAdd(spirv::NVCooperativeMatrixMulAddOp op) {
4036  if (op.getC().getType() != op.getResult().getType())
4037  return op.emitOpError("result and third operand must have the same type");
4038  auto typeA = op.getA().getType().cast<spirv::CooperativeMatrixNVType>();
4039  auto typeB = op.getB().getType().cast<spirv::CooperativeMatrixNVType>();
4040  auto typeC = op.getC().getType().cast<spirv::CooperativeMatrixNVType>();
4041  auto typeR = op.getResult().getType().cast<spirv::CooperativeMatrixNVType>();
4042  if (typeA.getRows() != typeR.getRows() ||
4043  typeA.getColumns() != typeB.getRows() ||
4044  typeB.getColumns() != typeR.getColumns())
4045  return op.emitOpError("matrix size must match");
4046  if (typeR.getScope() != typeA.getScope() ||
4047  typeR.getScope() != typeB.getScope() ||
4048  typeR.getScope() != typeC.getScope())
4049  return op.emitOpError("matrix scope must match");
4050  if (typeA.getElementType() != typeB.getElementType() ||
4051  typeR.getElementType() != typeC.getElementType())
4052  return op.emitOpError("matrix element type must match");
4053  return success();
4054 }
4055 
4057  return verifyCoopMatrixMulAdd(*this);
4058 }
4059 
4060 static LogicalResult
4062  Type pointeeType = pointer.cast<spirv::PointerType>().getPointeeType();
4063  if (!pointeeType.isa<spirv::ScalarType>() && !pointeeType.isa<VectorType>())
4064  return op->emitError(
4065  "Pointer must point to a scalar or vector type but provided ")
4066  << pointeeType;
4067  spirv::StorageClass storage =
4068  pointer.cast<spirv::PointerType>().getStorageClass();
4069  if (storage != spirv::StorageClass::Workgroup &&
4070  storage != spirv::StorageClass::CrossWorkgroup &&
4071  storage != spirv::StorageClass::UniformConstant &&
4072  storage != spirv::StorageClass::Generic)
4073  return op->emitError("Pointer storage class must be Workgroup or "
4074  "CrossWorkgroup but provided ")
4075  << stringifyStorageClass(storage);
4076  return success();
4077 }
4078 
4079 //===----------------------------------------------------------------------===//
4080 // spirv.INTEL.JointMatrixLoad
4081 //===----------------------------------------------------------------------===//
4082 
4084  return verifyPointerAndJointMatrixType(*this, getPointer().getType(),
4085  getResult().getType());
4086 }
4087 
4088 //===----------------------------------------------------------------------===//
4089 // spirv.INTEL.JointMatrixStore
4090 //===----------------------------------------------------------------------===//
4091 
4093  return verifyPointerAndJointMatrixType(*this, getPointer().getType(),
4094  getObject().getType());
4095 }
4096 
4097 //===----------------------------------------------------------------------===//
4098 // spirv.INTEL.JointMatrixMad
4099 //===----------------------------------------------------------------------===//
4100 
4101 static LogicalResult verifyJointMatrixMad(spirv::INTELJointMatrixMadOp op) {
4102  if (op.getC().getType() != op.getResult().getType())
4103  return op.emitOpError("result and third operand must have the same type");
4104  auto typeA = op.getA().getType().cast<spirv::JointMatrixINTELType>();
4105  auto typeB = op.getB().getType().cast<spirv::JointMatrixINTELType>();
4106  auto typeC = op.getC().getType().cast<spirv::JointMatrixINTELType>();
4107  auto typeR = op.getResult().getType().cast<spirv::JointMatrixINTELType>();
4108  if (typeA.getRows() != typeR.getRows() ||
4109  typeA.getColumns() != typeB.getRows() ||
4110  typeB.getColumns() != typeR.getColumns())
4111  return op.emitOpError("matrix size must match");
4112  if (typeR.getScope() != typeA.getScope() ||
4113  typeR.getScope() != typeB.getScope() ||
4114  typeR.getScope() != typeC.getScope())
4115  return op.emitOpError("matrix scope must match");
4116  if (typeA.getElementType() != typeB.getElementType() ||
4117  typeR.getElementType() != typeC.getElementType())
4118  return op.emitOpError("matrix element type must match");
4119  return success();
4120 }
4121 
4123  return verifyJointMatrixMad(*this);
4124 }
4125 
4126 //===----------------------------------------------------------------------===//
4127 // spirv.MatrixTimesScalar
4128 //===----------------------------------------------------------------------===//
4129 
4131  // We already checked that result and matrix are both of matrix type in the
4132  // auto-generated verify method.
4133 
4134  auto inputMatrix = getMatrix().getType().cast<spirv::MatrixType>();
4135  auto resultMatrix = getResult().getType().cast<spirv::MatrixType>();
4136 
4137  // Check that the scalar type is the same as the matrix element type.
4138  if (getScalar().getType() != inputMatrix.getElementType())
4139  return emitError("input matrix components' type and scaling value must "
4140  "have the same type");
4141 
4142  // Note that the next three checks could be done using the AllTypesMatch
4143  // trait in the Op definition file but it generates a vague error message.
4144 
4145  // Check that the input and result matrices have the same columns' count
4146  if (inputMatrix.getNumColumns() != resultMatrix.getNumColumns())
4147  return emitError("input and result matrices must have the same "
4148  "number of columns");
4149 
4150  // Check that the input and result matrices' have the same rows count
4151  if (inputMatrix.getNumRows() != resultMatrix.getNumRows())
4152  return emitError("input and result matrices' columns must have "
4153  "the same size");
4154 
4155  // Check that the input and result matrices' have the same component type
4156  if (inputMatrix.getElementType() != resultMatrix.getElementType())
4157  return emitError("input and result matrices' columns must have "
4158  "the same component type");
4159 
4160  return success();
4161 }
4162 
4163 //===----------------------------------------------------------------------===//
4164 // spirv.CopyMemory
4165 //===----------------------------------------------------------------------===//
4166 
4168  printer << ' ';
4169 
4170  StringRef targetStorageClass = stringifyStorageClass(
4171  getTarget().getType().cast<spirv::PointerType>().getStorageClass());
4172  printer << " \"" << targetStorageClass << "\" " << getTarget() << ", ";
4173 
4174  StringRef sourceStorageClass = stringifyStorageClass(
4175  getSource().getType().cast<spirv::PointerType>().getStorageClass());
4176  printer << " \"" << sourceStorageClass << "\" " << getSource();
4177 
4178  SmallVector<StringRef, 4> elidedAttrs;
4179  printMemoryAccessAttribute(*this, printer, elidedAttrs);
4180  printSourceMemoryAccessAttribute(*this, printer, elidedAttrs,
4181  getSourceMemoryAccess(),
4182  getSourceAlignment());
4183 
4184  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
4185 
4186  Type pointeeType =
4187  getTarget().getType().cast<spirv::PointerType>().getPointeeType();
4188  printer << " : " << pointeeType;
4189 }
4190 
4191 ParseResult spirv::CopyMemoryOp::parse(OpAsmParser &parser,
4192  OperationState &result) {
4193  spirv::StorageClass targetStorageClass;
4194  OpAsmParser::UnresolvedOperand targetPtrInfo;
4195 
4196  spirv::StorageClass sourceStorageClass;
4197  OpAsmParser::UnresolvedOperand sourcePtrInfo;
4198 
4199  Type elementType;
4200 
4201  if (parseEnumStrAttr(targetStorageClass, parser) ||
4202  parser.parseOperand(targetPtrInfo) || parser.parseComma() ||
4203  parseEnumStrAttr(sourceStorageClass, parser) ||
4204  parser.parseOperand(sourcePtrInfo) ||
4205  parseMemoryAccessAttributes(parser, result)) {
4206  return failure();
4207  }
4208 
4209  if (!parser.parseOptionalComma()) {
4210  // Parse 2nd memory access attributes.
4211  if (parseSourceMemoryAccessAttributes(parser, result)) {
4212  return failure();
4213  }
4214  }
4215 
4216  if (parser.parseColon() || parser.parseType(elementType))
4217  return failure();
4218 
4219  if (parser.parseOptionalAttrDict(result.attributes))
4220  return failure();
4221 
4222  auto targetPtrType = spirv::PointerType::get(elementType, targetStorageClass);
4223  auto sourcePtrType = spirv::PointerType::get(elementType, sourceStorageClass);
4224 
4225  if (parser.resolveOperand(targetPtrInfo, targetPtrType, result.operands) ||
4226  parser.resolveOperand(sourcePtrInfo, sourcePtrType, result.operands)) {
4227  return failure();
4228  }
4229 
4230  return success();
4231 }
4232 
4234  Type targetType =
4235  getTarget().getType().cast<spirv::PointerType>().getPointeeType();
4236 
4237  Type sourceType =
4238  getSource().getType().cast<spirv::PointerType>().getPointeeType();
4239 
4240  if (targetType != sourceType)
4241  return emitOpError("both operands must be pointers to the same type");
4242 
4243  if (failed(verifyMemoryAccessAttribute(*this)))
4244  return failure();
4245 
4246  // TODO - According to the spec:
4247  //
4248  // If two masks are present, the first applies to Target and cannot include
4249  // MakePointerVisible, and the second applies to Source and cannot include
4250  // MakePointerAvailable.
4251  //
4252  // Add such verification here.
4253 
4254  return verifySourceMemoryAccessAttribute(*this);
4255 }
4256 
4257 //===----------------------------------------------------------------------===//
4258 // spirv.Transpose
4259 //===----------------------------------------------------------------------===//
4260 
4262  auto inputMatrix = getMatrix().getType().cast<spirv::MatrixType>();
4263  auto resultMatrix = getResult().getType().cast<spirv::MatrixType>();
4264 
4265  // Verify that the input and output matrices have correct shapes.
4266  if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
4267  return emitError("input matrix rows count must be equal to "
4268  "output matrix columns count");
4269 
4270  if (inputMatrix.getNumColumns() != resultMatrix.getNumRows())
4271  return emitError("input matrix columns count must be equal to "
4272  "output matrix rows count");
4273 
4274  // Verify that the input and output matrices have the same component type
4275  if (inputMatrix.getElementType() != resultMatrix.getElementType())
4276  return emitError("input and output matrices must have the same "
4277  "component type");
4278 
4279  return success();
4280 }
4281 
4282 //===----------------------------------------------------------------------===//
4283 // spirv.MatrixTimesMatrix
4284 //===----------------------------------------------------------------------===//
4285 
4287  auto leftMatrix = getLeftmatrix().getType().cast<spirv::MatrixType>();
4288  auto rightMatrix = getRightmatrix().getType().cast<spirv::MatrixType>();
4289  auto resultMatrix = getResult().getType().cast<spirv::MatrixType>();
4290 
4291  // left matrix columns' count and right matrix rows' count must be equal
4292  if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
4293  return emitError("left matrix columns' count must be equal to "
4294  "the right matrix rows' count");
4295 
4296  // right and result matrices columns' count must be the same
4297  if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns())
4298  return emitError(
4299  "right and result matrices must have equal columns' count");
4300 
4301  // right and result matrices component type must be the same
4302  if (rightMatrix.getElementType() != resultMatrix.getElementType())
4303  return emitError("right and result matrices' component type must"
4304  " be the same");
4305 
4306  // left and result matrices component type must be the same
4307  if (leftMatrix.getElementType() != resultMatrix.getElementType())
4308  return emitError("left and result matrices' component type"
4309  " must be the same");
4310 
4311  // left and result matrices rows count must be the same
4312  if (leftMatrix.getNumRows() != resultMatrix.getNumRows())
4313  return emitError("left and result matrices must have equal rows' count");
4314 
4315  return success();
4316 }
4317 
4318 //===----------------------------------------------------------------------===//
4319 // spirv.SpecConstantComposite
4320 //===----------------------------------------------------------------------===//
4321 
4322 ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser,
4323  OperationState &result) {
4324 
4325  StringAttr compositeName;
4326  if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
4327  result.attributes))
4328  return failure();
4329 
4330  if (parser.parseLParen())
4331  return failure();
4332 
4333  SmallVector<Attribute, 4> constituents;
4334 
4335  do {
4336  // The name of the constituent attribute isn't important
4337  const char *attrName = "spec_const";
4338  FlatSymbolRefAttr specConstRef;
4339  NamedAttrList attrs;
4340 
4341  if (parser.parseAttribute(specConstRef, Type(), attrName, attrs))
4342  return failure();
4343 
4344  constituents.push_back(specConstRef);
4345  } while (!parser.parseOptionalComma());
4346 
4347  if (parser.parseRParen())
4348  return failure();
4349 
4351  parser.getBuilder().getArrayAttr(constituents));
4352 
4353  Type type;
4354  if (parser.parseColonType(type))
4355  return failure();
4356 
4357  result.addAttribute(kTypeAttrName, TypeAttr::get(type));
4358 
4359  return success();
4360 }
4361 
4363  printer << " ";
4364  printer.printSymbolName(getSymName());
4365  printer << " (";
4366  auto constituents = this->getConstituents().getValue();
4367 
4368  if (!constituents.empty())
4369  llvm::interleaveComma(constituents, printer);
4370 
4371  printer << ") : " << getType();
4372 }
4373 
4375  auto cType = getType().dyn_cast<spirv::CompositeType>();
4376  auto constituents = this->getConstituents().getValue();
4377 
4378  if (!cType)
4379  return emitError("result type must be a composite type, but provided ")
4380  << getType();
4381 
4382  if (cType.isa<spirv::CooperativeMatrixNVType>())
4383  return emitError("unsupported composite type ") << cType;
4384  if (cType.isa<spirv::JointMatrixINTELType>())
4385  return emitError("unsupported composite type ") << cType;
4386  if (constituents.size() != cType.getNumElements())
4387  return emitError("has incorrect number of operands: expected ")
4388  << cType.getNumElements() << ", but provided "
4389  << constituents.size();
4390 
4391  for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
4392  auto constituent = constituents[index].cast<FlatSymbolRefAttr>();
4393 
4394  auto constituentSpecConstOp =
4395  dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom(
4396  (*this)->getParentOp(), constituent.getAttr()));
4397 
4398  if (constituentSpecConstOp.getDefaultValue().getType() !=
4399  cType.getElementType(index))
4400  return emitError("has incorrect types of operands: expected ")
4401  << cType.getElementType(index) << ", but provided "
4402  << constituentSpecConstOp.getDefaultValue().getType();
4403  }
4404 
4405  return success();
4406 }
4407 
4408 //===----------------------------------------------------------------------===//
4409 // spirv.SpecConstantOperation
4410 //===----------------------------------------------------------------------===//
4411 
4412 ParseResult spirv::SpecConstantOperationOp::parse(OpAsmParser &parser,
4413  OperationState &result) {
4414  Region *body = result.addRegion();
4415 
4416  if (parser.parseKeyword("wraps"))
4417  return failure();
4418 
4419  body->push_back(new Block);
4420  Block &block = body->back();
4421  Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
4422 
4423  if (!wrappedOp)
4424  return failure();
4425 
4426  OpBuilder builder(parser.getContext());
4427  builder.setInsertionPointToEnd(&block);
4428  builder.create<spirv::YieldOp>(wrappedOp->getLoc(), wrappedOp->getResult(0));
4429  result.location = wrappedOp->getLoc();
4430 
4431  result.addTypes(wrappedOp->getResult(0).getType());
4432 
4433  if (parser.parseOptionalAttrDict(result.attributes))
4434  return failure();
4435 
4436  return success();
4437 }
4438 
4440  printer << " wraps ";
4441  printer.printGenericOp(&getBody().front().front());
4442 }
4443 
4444 LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
4445  Block &block = getRegion().getBlocks().front();
4446 
4447  if (block.getOperations().size() != 2)
4448  return emitOpError("expected exactly 2 nested ops");
4449 
4450  Operation &enclosedOp = block.getOperations().front();
4451 
4453  return emitOpError("invalid enclosed op");
4454 
4455  for (auto operand : enclosedOp.getOperands())
4456  if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
4457  spirv::SpecConstantOperationOp>(operand.getDefiningOp()))
4458  return emitOpError(
4459  "invalid operand, must be defined by a constant operation");
4460 
4461  return success();
4462 }
4463 
4464 //===----------------------------------------------------------------------===//
4465 // spirv.GL.FrexpStruct
4466 //===----------------------------------------------------------------------===//
4467 
4469  spirv::StructType structTy =
4470  getResult().getType().dyn_cast<spirv::StructType>();
4471 
4472  if (structTy.getNumElements() != 2)
4473  return emitError("result type must be a struct type with two memebers");
4474 
4475  Type significandTy = structTy.getElementType(0);
4476  Type exponentTy = structTy.getElementType(1);
4477  VectorType exponentVecTy = exponentTy.dyn_cast<VectorType>();
4478  IntegerType exponentIntTy = exponentTy.dyn_cast<IntegerType>();
4479 
4480  Type operandTy = getOperand().getType();
4481  VectorType operandVecTy = operandTy.dyn_cast<VectorType>();
4482  FloatType operandFTy = operandTy.dyn_cast<FloatType>();
4483 
4484  if (significandTy != operandTy)
4485  return emitError("member zero of the resulting struct type must be the "
4486  "same type as the operand");
4487 
4488  if (exponentVecTy) {
4489  IntegerType componentIntTy =
4490  exponentVecTy.getElementType().dyn_cast<IntegerType>();
4491  if (!componentIntTy || componentIntTy.getWidth() != 32)
4492  return emitError("member one of the resulting struct type must"
4493  "be a scalar or vector of 32 bit integer type");
4494  } else if (!exponentIntTy || exponentIntTy.getWidth() != 32) {
4495  return emitError("member one of the resulting struct type "
4496  "must be a scalar or vector of 32 bit integer type");
4497  }
4498 
4499  // Check that the two member types have the same number of components
4500  if (operandVecTy && exponentVecTy &&
4501  (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
4502  return success();
4503 
4504  if (operandFTy && exponentIntTy)
4505  return success();
4506 
4507  return emitError("member one of the resulting struct type must have the same "
4508  "number of components as the operand type");
4509 }
4510 
4511 //===----------------------------------------------------------------------===//
4512 // spirv.GL.Ldexp
4513 //===----------------------------------------------------------------------===//
4514 
4516  Type significandType = getX().getType();
4517  Type exponentType = getExp().getType();
4518 
4519  if (significandType.isa<FloatType>() != exponentType.isa<IntegerType>())
4520  return emitOpError("operands must both be scalars or vectors");
4521 
4522  auto getNumElements = [](Type type) -> unsigned {
4523  if (auto vectorType = type.dyn_cast<VectorType>())
4524  return vectorType.getNumElements();
4525  return 1;
4526  };
4527 
4528  if (getNumElements(significandType) != getNumElements(exponentType))
4529  return emitOpError("operands must have the same number of elements");
4530 
4531  return success();
4532 }
4533 
4534 //===----------------------------------------------------------------------===//
4535 // spirv.ImageDrefGather
4536 //===----------------------------------------------------------------------===//
4537 
4539  VectorType resultType = getResult().getType().cast<VectorType>();
4540  auto sampledImageType =
4541  getSampledimage().getType().cast<spirv::SampledImageType>();
4542  auto imageType = sampledImageType.getImageType().cast<spirv::ImageType>();
4543 
4544  if (resultType.getNumElements() != 4)
4545  return emitOpError("result type must be a vector of four components");
4546 
4547  Type elementType = resultType.getElementType();
4548  Type sampledElementType = imageType.getElementType();
4549  if (!sampledElementType.isa<NoneType>() && elementType != sampledElementType)
4550  return emitOpError(
4551  "the component type of result must be the same as sampled type of the "
4552  "underlying image type");
4553 
4554  spirv::Dim imageDim = imageType.getDim();
4555  spirv::ImageSamplingInfo imageMS = imageType.getSamplingInfo();
4556 
4557  if (imageDim != spirv::Dim::Dim2D && imageDim != spirv::Dim::Cube &&
4558  imageDim != spirv::Dim::Rect)
4559  return emitOpError(
4560  "the Dim operand of the underlying image type must be 2D, Cube, or "
4561  "Rect");
4562 
4563  if (imageMS != spirv::ImageSamplingInfo::SingleSampled)
4564  return emitOpError("the MS operand of the underlying image type must be 0");
4565 
4566  spirv::ImageOperandsAttr attr = getImageoperandsAttr();
4567  auto operandArguments = getOperandArguments();
4568 
4569  return verifyImageOperands(*this, attr, operandArguments);
4570 }
4571 
4572 //===----------------------------------------------------------------------===//
4573 // spirv.ShiftLeftLogicalOp
4574 //===----------------------------------------------------------------------===//
4575 
4577  return verifyShiftOp(*this);
4578 }
4579 
4580 //===----------------------------------------------------------------------===//
4581 // spirv.ShiftRightArithmeticOp
4582 //===----------------------------------------------------------------------===//
4583 
4585  return verifyShiftOp(*this);
4586 }
4587 
4588 //===----------------------------------------------------------------------===//
4589 // spirv.ShiftRightLogicalOp
4590 //===----------------------------------------------------------------------===//
4591 
4593  return verifyShiftOp(*this);
4594 }
4595 
4596 //===----------------------------------------------------------------------===//
4597 // spirv.ImageQuerySize
4598 //===----------------------------------------------------------------------===//
4599 
4601  spirv::ImageType imageType = getImage().getType().cast<spirv::ImageType>();
4602  Type resultType = getResult().getType();
4603 
4604  spirv::Dim dim = imageType.getDim();
4605  spirv::ImageSamplingInfo samplingInfo = imageType.getSamplingInfo();
4606  spirv::ImageSamplerUseInfo samplerInfo = imageType.getSamplerUseInfo();
4607  switch (dim) {
4608  case spirv::Dim::Dim1D:
4609  case spirv::Dim::Dim2D:
4610  case spirv::Dim::Dim3D:
4611  case spirv::Dim::Cube:
4612  if (samplingInfo != spirv::ImageSamplingInfo::MultiSampled &&
4613  samplerInfo != spirv::ImageSamplerUseInfo::SamplerUnknown &&
4614  samplerInfo != spirv::ImageSamplerUseInfo::NoSampler)
4615  return emitError(
4616  "if Dim is 1D, 2D, 3D, or Cube, "
4617  "it must also have either an MS of 1 or a Sampled of 0 or 2");
4618  break;
4619  case spirv::Dim::Buffer:
4620  case spirv::Dim::Rect:
4621  break;
4622  default:
4623  return emitError("the Dim operand of the image type must "
4624  "be 1D, 2D, 3D, Buffer, Cube, or Rect");
4625  }
4626 
4627  unsigned componentNumber = 0;
4628  switch (dim) {
4629  case spirv::Dim::Dim1D:
4630  case spirv::Dim::Buffer:
4631  componentNumber = 1;
4632  break;
4633  case spirv::Dim::Dim2D:
4634  case spirv::Dim::Cube:
4635  case spirv::Dim::Rect:
4636  componentNumber = 2;
4637  break;
4638  case spirv::Dim::Dim3D:
4639  componentNumber = 3;
4640  break;
4641  default:
4642  break;
4643  }
4644 
4645  if (imageType.getArrayedInfo() == spirv::ImageArrayedInfo::Arrayed)
4646  componentNumber += 1;
4647 
4648  unsigned resultComponentNumber = 1;
4649  if (auto resultVectorType = resultType.dyn_cast<VectorType>())
4650  resultComponentNumber = resultVectorType.getNumElements();
4651 
4652  if (componentNumber != resultComponentNumber)
4653  return emitError("expected the result to have ")
4654  << componentNumber << " component(s), but found "
4655  << resultComponentNumber << " component(s)";
4656 
4657  return success();
4658 }
4659 
4660 static ParseResult parsePtrAccessChainOpImpl(StringRef opName,
4661  OpAsmParser &parser,
4662  OperationState &state) {
4665  Type type;
4666  auto loc = parser.getCurrentLocation();
4667  SmallVector<Type, 4> indicesTypes;
4668 
4669  if (parser.parseOperand(ptrInfo) ||
4670  parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
4671  parser.parseColonType(type) ||
4672  parser.resolveOperand(ptrInfo, type, state.operands))
4673  return failure();
4674 
4675  // Check that the provided indices list is not empty before parsing their
4676  // type list.
4677  if (indicesInfo.empty())
4678  return emitError(state.location) << opName << " expected element";
4679 
4680  if (parser.parseComma() || parser.parseTypeList(indicesTypes))
4681  return failure();
4682 
4683  // Check that the indices types list is not empty and that it has a one-to-one
4684  // mapping to the provided indices.
4685  if (indicesTypes.size() != indicesInfo.size())
4686  return emitError(state.location)
4687  << opName
4688  << " indices types' count must be equal to indices info count";
4689 
4690  if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands))
4691  return failure();
4692 
4693  auto resultType = getElementPtrType(
4694  type, llvm::makeArrayRef(state.operands).drop_front(2), state.location);
4695  if (!resultType)
4696  return failure();
4697 
4698  state.addTypes(resultType);
4699  return success();
4700 }
4701 
4702 template <typename Op>
4703 static auto concatElemAndIndices(Op op) {
4704  SmallVector<Value> ret(op.getIndices().size() + 1);
4705  ret[0] = op.getElement();
4706  llvm::copy(op.getIndices(), ret.begin() + 1);
4707  return ret;
4708 }
4709 
4710 //===----------------------------------------------------------------------===//
4711 // spirv.InBoundsPtrAccessChainOp
4712 //===----------------------------------------------------------------------===//
4713 
4714 void spirv::InBoundsPtrAccessChainOp::build(OpBuilder &builder,
4715  OperationState &state,
4716  Value basePtr, Value element,
4717  ValueRange indices) {
4718  auto type = getElementPtrType(basePtr.getType(), indices, state.location);
4719  assert(type && "Unable to deduce return type based on basePtr and indices");
4720  build(builder, state, type, basePtr, element, indices);
4721 }
4722 
4723 ParseResult spirv::InBoundsPtrAccessChainOp::parse(OpAsmParser &parser,
4724  OperationState &result) {
4726  spirv::InBoundsPtrAccessChainOp::getOperationName(), parser, result);
4727 }
4728 
4730  printAccessChain(*this, concatElemAndIndices(*this), printer);
4731 }
4732 
4734  return verifyAccessChain(*this, getIndices());
4735 }
4736 
4737 //===----------------------------------------------------------------------===//
4738 // spirv.PtrAccessChainOp
4739 //===----------------------------------------------------------------------===//
4740 
4741 void spirv::PtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
4742  Value basePtr, Value element,
4743  ValueRange indices) {
4744  auto type = getElementPtrType(basePtr.getType(), indices, state.location);
4745  assert(type && "Unable to deduce return type based on basePtr and indices");
4746  build(builder, state, type, basePtr, element, indices);
4747 }
4748 
4749 ParseResult spirv::PtrAccessChainOp::parse(OpAsmParser &parser,
4750  OperationState &result) {
4751  return parsePtrAccessChainOpImpl(spirv::PtrAccessChainOp::getOperationName(),
4752  parser, result);
4753 }
4754 
4756  printAccessChain(*this, concatElemAndIndices(*this), printer);
4757 }
4758 
4760  return verifyAccessChain(*this, getIndices());
4761 }
4762 
4763 //===----------------------------------------------------------------------===//
4764 // spirv.VectorTimesScalarOp
4765 //===----------------------------------------------------------------------===//
4766 
4768  if (getVector().getType() != getType())
4769  return emitOpError("vector operand and result type mismatch");
4770  auto scalarType = getType().cast<VectorType>().getElementType();
4771  if (getScalar().getType() != scalarType)
4772  return emitOpError("scalar operand and result element type match");
4773  return success();
4774 }
4775 
4776 //===----------------------------------------------------------------------===//
4777 // Group ops
4778 //===----------------------------------------------------------------------===//
4779 
4780 template <typename Op>
4782  spirv::Scope scope = op.getExecutionScope();
4783  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
4784  return op.emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
4785 
4786  return success();
4787 }
4788 
4790 
4792 
4794 
4796 
4798 
4800 
4802 
4804 
4806 
4808 
4809 // TableGen'erated operation interfaces for querying versions, extensions, and
4810 // capabilities.
4811 #include "mlir/Dialect/SPIRV/IR/SPIRVAvailability.cpp.inc"
4812 
4813 // TablenGen'erated operation definitions.
4814 #define GET_OP_CLASSES
4815 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"
4816 
4817 namespace mlir {
4818 namespace spirv {
4819 // TableGen'erated operation availability interface implementations.
4820 #include "mlir/Dialect/SPIRV/IR/SPIRVOpAvailabilityImpl.inc"
4821 } // namespace spirv
4822 } // namespace mlir
static std::string bindingName()
Returns the string name of the Binding decoration.
static std::string descriptorSetName()
Returns the string name of the DescriptorSet decoration.
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static constexpr const bool value
@ None
Operation::operand_range getIndices(Operation *op)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser, OperationState &result)
Definition: SPIRVOps.cpp:783
static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:901
constexpr char kValuesAttrName[]
Definition: SPIRVOps.cpp:63
static LogicalResult verifyPointerAndJointMatrixType(Operation *op, Type pointer, Type jointMatrix)
Definition: SPIRVOps.cpp:4061
static ParseResult parseVariableDecorations(OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:609
constexpr char kClusterSize[]
Definition: SPIRVOps.cpp:47
static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp)
Definition: SPIRVOps.cpp:956
constexpr char kValueAttrName[]
Definition: SPIRVOps.cpp:62
static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, Type opType)
Definition: SPIRVOps.cpp:1985
constexpr char kEqualSemanticsAttrName[]
Definition: SPIRVOps.cpp:51
static LogicalResult verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics)
Definition: SPIRVOps.cpp:558
static LogicalResult verifyCastOp(Operation *op, bool requireSameBitWidth=true, bool skipBitWidthCheck=false)
Definition: SPIRVOps.cpp:420
static ParseResult parseAtomicUpdateOp(OpAsmParser &parser, OperationState &state, bool hasValue)
Definition: SPIRVOps.cpp:821
constexpr char kExecutionScopeAttrName[]
Definition: SPIRVOps.cpp:50
static void printVariableDecorations(Operation *op, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs)
Definition: SPIRVOps.cpp:646
static unsigned getBitWidth(Type type)
Definition: SPIRVOps.cpp:674
constexpr char kTypeAttrName[]
Definition: SPIRVOps.cpp:60
static bool isMergeBlock(Block &block)
Returns true if the given block only contains one spirv.mlir.merge op.
Definition: SPIRVOps.cpp:763
static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, OperationState &result)
Definition: SPIRVOps.cpp:70
constexpr char kUnequalSemanticsAttrName[]
Definition: SPIRVOps.cpp:61
static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices)
Definition: SPIRVOps.cpp:1150
constexpr char kIndicesAttrName[]
Definition: SPIRVOps.cpp:54
static void printAtomicCompareExchangeImpl(T atomOp, OpAsmPrinter &printer)
Definition: SPIRVOps.cpp:1198
static ParseResult parseImageOperands(OpAsmParser &parser, spirv::ImageOperandsAttr &attr)
Definition: SPIRVOps.cpp:369
static StringRef stringifyTypeName()
static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op)
Definition: SPIRVOps.cpp:769
static LogicalResult verifyGroupNonUniformShuffleOp(OpTy op)
Definition: SPIRVOps.cpp:2766
static LogicalResult verifyAtomicUpdateOp(Operation *op)
Definition: SPIRVOps.cpp:877
constexpr char kBranchWeightAttrName[]
Definition: SPIRVOps.cpp:45
constexpr char kCompositeSpecConstituentsName[]
Definition: SPIRVOps.cpp:64
static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:283
constexpr char kInterfaceAttrName[]
Definition: SPIRVOps.cpp:56
static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr, Value val)
Definition: SPIRVOps.cpp:582
constexpr char kSourceAlignmentAttrName[]
Definition: SPIRVOps.cpp:44
static void printMemoryAccessAttribute(MemoryOpTy memoryOp, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs, Optional< spirv::MemoryAccess > memoryAccessAtrrValue=None, Optional< uint32_t > alignmentAttrValue=None)
Definition: SPIRVOps.cpp:311
constexpr char kCallee[]
Definition: SPIRVOps.cpp:46
constexpr char kMemoryScopeAttrName[]
Definition: SPIRVOps.cpp:57
static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer, Type coopMatrix)
Definition: SPIRVOps.cpp:3966
static bool isDirectInModuleLikeOp(Operation *op)
Returns true if the given op is an module-like op that maintains a symbol table.
Definition: SPIRVOps.cpp:132
static LogicalResult verifyShiftOp(Operation *op)
Definition: SPIRVOps.cpp:996
constexpr char kGroupOperationAttrName[]
Definition: SPIRVOps.cpp:53
static Type getUnaryOpResultType(Type operandType)
Result of a logical op must be a scalar or vector of boolean type.
Definition: SPIRVOps.cpp:988
static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser, OperationState &state)
Parses optional memory access attributes attached to a memory access operand/pointer.
Definition: SPIRVOps.cpp:252
static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc)
Definition: SPIRVOps.cpp:1032
static LogicalResult verifyJointMatrixMad(spirv::INTELJointMatrixMadOp op)
Definition: SPIRVOps.cpp:4101
constexpr char kMemoryAccessAttrName[]
Definition: SPIRVOps.cpp:41
constexpr char kControl[]
Definition: SPIRVOps.cpp:48
constexpr char kDefaultValueAttrName[]
Definition: SPIRVOps.cpp:49
static LogicalResult verifyGroupOp(Op op)
Definition: SPIRVOps.cpp:4781
static void buildLogicalUnaryOp(OpBuilder &builder, OperationState &state, Value value)
Definition: SPIRVOps.cpp:1018
static LogicalResult extractValueFromConstOp(Operation *op, int32_t &value)
Definition: SPIRVOps.cpp:136
static void printAtomicUpdateOp(Operation *op, OpAsmPrinter &printer)
Definition: SPIRVOps.cpp:852
StringRef stringifyTypeName< FloatType >()
Definition: SPIRVOps.cpp:871
static auto concatElemAndIndices(Op op)
Definition: SPIRVOps.cpp:4703
static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op, Value ptr, Value val)
Definition: SPIRVOps.cpp:597
static void printGroupNonUniformArithmeticOp(Operation *groupOp, OpAsmPrinter &printer)
Definition: SPIRVOps.cpp:937
static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer)
Definition: SPIRVOps.cpp:1140
static bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock)
Returns true if the given srcBlock contains only one spirv.Branch to the given dstBlock.
Definition: SPIRVOps.cpp:3191
static void printSourceMemoryAccessAttribute(MemoryOpTy memoryOp, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs, Optional< spirv::MemoryAccess > memoryAccessAtrrValue=None, Optional< uint32_t > alignmentAttrValue=None)
Definition: SPIRVOps.cpp:341
static bool isNestedInFunctionOpInterface(Operation *op)
Returns true if the given op is a function-like op or nested in a function-like op without a module-l...
Definition: SPIRVOps.cpp:120
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:696
static void buildLogicalBinaryOp(OpBuilder &builder, OperationState &state, Value lhs, Value rhs)
Definition: SPIRVOps.cpp:1006
static void printOneResultOp(Operation *op, OpAsmPrinter &p)
Definition: SPIRVOps.cpp:99
static void printImageOperands(OpAsmPrinter &printer, Operation *imageOp, spirv::ImageOperandsAttr attr)
Definition: SPIRVOps.cpp:384
static ParseResult parsePtrAccessChainOpImpl(StringRef opName, OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:4660
constexpr char kSourceMemoryAccessAttrName[]
Definition: SPIRVOps.cpp:42
constexpr char kSpecIdAttrName[]
Definition: SPIRVOps.cpp:59
constexpr char kAlignmentAttrName[]
Definition: SPIRVOps.cpp:43
static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp)
Definition: SPIRVOps.cpp:518
static ArrayAttr getStrArrayAttrForEnumList(Builder &builder, ArrayRef< Ty > enumValues, function_ref< StringRef(Ty)> stringifyFn)
Definition: SPIRVOps.cpp:157
static ParseResult parseAtomicCompareExchangeImpl(OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:1205
StringRef stringifyTypeName< IntegerType >()
Definition: SPIRVOps.cpp:866
static LogicalResult verifyImageOperands(Op imageOp, spirv::ImageOperandsAttr attr, Operation::operand_range operands)
Definition: SPIRVOps.cpp:393
static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp)
Definition: SPIRVOps.cpp:474
constexpr char kSemanticsAttrName[]
Definition: SPIRVOps.cpp:58
static LogicalResult verifyAtomicCompareExchangeImpl(T atomOp)
Definition: SPIRVOps.cpp:1238
static ParseResult parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next string attribute in parser as an enumerant of the given EnumClass.
Definition: SPIRVOps.cpp:174
static LogicalResult verifyCoopMatrixMulAdd(spirv::NVCooperativeMatrixMulAddOp op)
Definition: SPIRVOps.cpp:4035
constexpr char kInitializerAttrName[]
Definition: SPIRVOps.cpp:55
static ParseResult parseControlAttribute(OpAsmParser &parser, OperationState &state, StringRef attrName=spirv::attributeName< EnumClass >())
Parses Function, Selection and Loop control attributes.
Definition: SPIRVOps.cpp:229
constexpr char kFnNameAttrName[]
Definition: SPIRVOps.cpp:52
static void printArithmeticExtendedBinaryOp(Operation *op, OpAsmPrinter &printer)
Definition: SPIRVOps.cpp:807
static int64_t getNumElements(ShapedType type)
Definition: TensorOps.cpp:1252
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseOptionalSymbolName(StringAttr &result)=0
Parse an optional -identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:67
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
ParseResult addTypesToList(ArrayRef< Type > types, SmallVectorImpl< Type > &result)
Add the specified types to the end of the specified type list and return success.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
U dyn_cast() const
Definition: Attributes.h:127
U cast() const
Definition: Attributes.h:137
bool isa() const
Casting utility functions.
Definition: Attributes.h:117
Block represents an ordered list of Operations.
Definition: Block.h:30
unsigned getNumSuccessors()
Definition: Block.cpp:238
bool empty()
Definition: Block.h:137
Operation & back()
Definition: Block.h:141
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:232
OpListType & getOperations()
Definition: Block.h:126
Operation & front()
Definition: Block.h:142
iterator end()
Definition: Block.h:133
iterator begin()
Definition: Block.h:132
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:49
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:190
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:153
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:212
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:257
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:235
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:81
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:72
NoneType getNoneType()
Definition: Builders.cpp:89
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:101
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:243
MLIRContext * getContext() const
Definition: Builders.h:54
IntegerType getI1Type()
Definition: Builders.cpp:58
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:247
ArrayAttr getStrArrayAttr(ArrayRef< StringRef > values)
Definition: Builders.cpp:288
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition: Builders.h:95
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:41
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:307
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:64
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual ParseResult parseSuccessorAndUseList(Block *&dest, SmallVectorImpl< Value > &operands)=0
Parse a single operation successor and its operand list.
virtual OptionalParseResult parseOptionalRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual Operation * parseGenericOperation(Block *insertBlock, Block::iterator insertPt)=0
Parse an operation in its generic form.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printSuccessorAndUseList(Block *successor, ValueRange succOperands)=0
Print the successor and its operands.
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printGenericOp(Operation *op, bool printOpName=true)=0
Print the entire operation with the default generic assembly form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:300
This class helps build Operations.
Definition: Builders.h:198
static OpBuilder atBlockEnd(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still insid...
Definition: Builders.h:236
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:388
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=llvm::None, ArrayRef< Location > locs=llvm::None)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:395
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:422
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:633
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:639
Location getLoc()
The source location the operation was defined or derived from.
Definition: OpDefinition.h:108
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:341
A trait to mark ops that can be enclosed/wrapped in a SpecConstantOperation op.
Definition: SPIRVOpTraits.h:33
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:41
type_range getType() const
Definition: ValueRange.cpp:30
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:31
Value getOperand(unsigned idx)
Definition: Operation.h:267
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:528
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:151
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:375
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:371
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:324
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:154
unsigned getNumOperands()
Definition: Operation.h:263
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:165
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:356
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:225
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:486
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:50
operand_type_range getOperandTypes()
Definition: Operation.h:314
result_type_range getResultTypes()
Definition: Operation.h:345
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:295
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:512
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:321
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:37
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:47
This class represents success/failure for parsing-like operations that find it important to chain tog...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
void push_back(Block *block)
Definition: Region.h:61
bool empty()
Definition: Region.h:60
Block & back()
Definition: Region.h:64
This class models how operands are forwarded to block arguments in control flow.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:23
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:58
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:78
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
U cast() const
Definition: Types.h:280
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:121
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:19
U dyn_cast() const
Definition: Types.h:270
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:89
bool isa() const
Definition: Types.h:260
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:93
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:349
type_range getTypes() const
Type front()
Return first type in the range.
Definition: TypeRange.h:148
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
Type getType() const
Return the type of this value.
Definition: Value.h:114
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:34
static WalkResult advance()
Definition: Visitors.h:51
Type getElementType() const
Definition: SPIRVTypes.cpp:64
unsigned getNumElements() const
Definition: SPIRVTypes.cpp:62
unsigned getNumElements() const
Return the number of elements of the type.
Definition: SPIRVTypes.cpp:125
unsigned getRows() const
return the number of rows of the matrix.
Definition: SPIRVTypes.cpp:242
unsigned getColumns() const
return the number of columns of the matrix.
Definition: SPIRVTypes.cpp:244
Scope getScope() const
Return the scope of the cooperative matrix.
Definition: SPIRVTypes.cpp:240
ImageArrayedInfo getArrayedInfo() const
Definition: SPIRVTypes.cpp:422
ImageSamplerUseInfo getSamplerUseInfo() const
Definition: SPIRVTypes.cpp:430
ImageSamplingInfo getSamplingInfo() const
Definition: SPIRVTypes.cpp:426
Scope getScope() const
Return the scope of the joint matrix.
Definition: SPIRVTypes.cpp:306
unsigned getColumns() const
return the number of columns of the matrix.
Definition: SPIRVTypes.cpp:310
unsigned getRows() const
return the number of rows of the matrix.
Definition: SPIRVTypes.cpp:308
unsigned getNumColumns() const
Returns the number of columns.
Type getElementType() const
Returns the elements' type (i.e, single element type).
unsigned getNumRows() const
Returns the number of rows.
Type getPointeeType() const
Definition: SPIRVTypes.cpp:480
StorageClass getStorageClass() const
Definition: SPIRVTypes.cpp:482
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:476
SPIR-V struct type.
Definition: SPIRVTypes.h:281
unsigned getNumElements() const
Type getElementType(unsigned) const
An attribute that specifies the SPIR-V (version, capabilities, extensions) triple.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.cpp:24
void printFunctionSignature(OpAsmPrinter &p, Operation *op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, unsigned numInputs, unsigned numResults, ArrayRef< StringRef > elided={})
Prints the list of function prefixed with the "attributes" keyword.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
StringRef getTypeAttrName()
Return the name of the attribute used for function types.
ParseResult parseFunctionSignature(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
Type getFunctionType(Builder &builder, ArrayRef< OpAsmParser::Argument > argAttrs, ArrayRef< Type > resultTypes)
Get a function type corresponding to an array of arguments (which have types) and a set of result typ...
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
FailureOr< ExecutionModel > getExecutionModel(TargetEnvAttr targetAttr)
Returns execution model selected based on target environment.
AddressingModel getAddressingModel(TargetEnvAttr targetAttr)
Returns addressing model selected based on target environment.
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
TargetEnvAttr getDefaultTargetEnv(MLIRContext *context)
Returns the default target environment: SPIR-V 1.0 with Shader capability and no extra extensions.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
static ParseResult parseEnumKeywordAttr(EnumClass &value, ParserType &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next keyword in parser as an enumerant of the given EnumClass.
Definition: ParserUtils.h:27
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:372
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
A callable is either a symbol, or an SSA value, that is referenced by a call-like operation.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addSuccessors(Block *successor)
void addTypes(ArrayRef< Type > newTypes)
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.