MLIR  14.0.0git
SPIRVOps.cpp
Go to the documentation of this file.
1 //===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the operations in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/BuiltinOps.h"
23 #include "mlir/IR/BuiltinTypes.h"
25 #include "mlir/IR/OpDefinition.h"
27 #include "mlir/IR/TypeUtilities.h"
29 #include "llvm/ADT/APFloat.h"
30 #include "llvm/ADT/APInt.h"
31 #include "llvm/ADT/StringExtras.h"
32 #include "llvm/ADT/bit.h"
33 
34 using namespace mlir;
35 
// Attribute and keyword spellings shared by the custom parsers, printers and
// verifiers in this file.
// TODO: generate these strings using ODS.

// Memory access / alignment attributes of memory operations.
static constexpr char kMemoryAccessAttrName[] = "memory_access";
static constexpr char kSourceMemoryAccessAttrName[] = "source_memory_access";
static constexpr char kAlignmentAttrName[] = "alignment";
static constexpr char kSourceAlignmentAttrName[] = "source_alignment";

// Remaining attribute and keyword names, in their original order.
static constexpr char kBranchWeightAttrName[] = "branch_weights";
static constexpr char kCallee[] = "callee";
static constexpr char kClusterSize[] = "cluster_size";
static constexpr char kControl[] = "control";
static constexpr char kDefaultValueAttrName[] = "default_value";
static constexpr char kExecutionScopeAttrName[] = "execution_scope";
static constexpr char kEqualSemanticsAttrName[] = "equal_semantics";
static constexpr char kFnNameAttrName[] = "fn";
static constexpr char kGroupOperationAttrName[] = "group_operation";
static constexpr char kIndicesAttrName[] = "indices";
static constexpr char kInitializerAttrName[] = "initializer";
static constexpr char kInterfaceAttrName[] = "interface";
static constexpr char kMemoryScopeAttrName[] = "memory_scope";
static constexpr char kSemanticsAttrName[] = "semantics";
static constexpr char kSpecIdAttrName[] = "spec_id";
static constexpr char kTypeAttrName[] = "type";
static constexpr char kUnequalSemanticsAttrName[] = "unequal_semantics";
static constexpr char kValueAttrName[] = "value";
static constexpr char kValuesAttrName[] = "values";
static constexpr char kCompositeSpecConstituentsName[] = "constituents";
62 
63 //===----------------------------------------------------------------------===//
64 // Common utility functions
65 //===----------------------------------------------------------------------===//
66 
67 /// Returns true if the given op is a function-like op or nested in a
68 /// function-like op without a module-like op in the middle.
// NOTE(review): this listing omits the signature line and the statement just
// before the closing brace. Presumably `op` is an `Operation *` and the
// missing statement repeats the query on the parent op — confirm against the
// full source.
70  if (!op)
71  return false;
// An op carrying the SymbolTable trait (a module-like op) ends the search
// with false.
72  if (op->hasTrait<OpTrait::SymbolTable>())
73  return false;
// A FunctionOpInterface op answers the query positively.
74  if (isa<FunctionOpInterface>(op))
75  return true;
77 }
78 
79 /// Returns true if the given op is a module-like op that maintains a symbol
80 /// table.
// NOTE(review): the signature line is missing from this listing; `op` is
// presumably an `Operation *` — confirm. A null `op` yields false.
82  return op && op->hasTrait<OpTrait::SymbolTable>();
83 }
84 
/// Reads the integer held by a spirv::ConstantOp into `value`.
/// Fails without writing `value` when `op` is null, is not a
/// spirv::ConstantOp, or its value attribute is not an IntegerAttr.
// NOTE(review): the signature line is missing from this listing; the visible
// caller passes an `Operation *` and an `int32_t` lvalue — confirm the
// declared parameter types.
86  auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
87  if (!constOp) {
88  return failure();
89  }
90  auto valueAttr = constOp.value();
91  auto integerValueAttr = valueAttr.dyn_cast<IntegerAttr>();
92  if (!integerValueAttr) {
93  return failure();
94  }
95 
// Signless integers are read with getInt(); every other integer type with
// getSInt().
// NOTE(review): both accessors return 64-bit values, so a narrower `value`
// (e.g. the int32_t used by the caller) silently truncates. getSInt() is
// meant for signed integer types; an unsigned-typed constant may trip its
// assertion — confirm.
96  if (integerValueAttr.getType().isSignlessInteger())
97  value = integerValueAttr.getInt();
98  else
99  value = integerValueAttr.getSInt();
100 
101  return success();
102 }
103 
/// Builds a string ArrayAttr holding `stringifyFn(v)` for every `v` in
/// `enumValues`, in order. Returns a null ArrayAttr when `enumValues` is
/// empty.
104 template <typename Ty>
105 static ArrayAttr
// NOTE(review): the line with the function name and the `builder` /
// `enumValues` parameters is missing from this listing; `enumValues` is
// presumably an ArrayRef<Ty> — confirm.
107  function_ref<StringRef(Ty)> stringifyFn) {
108  if (enumValues.empty()) {
109  return nullptr;
110  }
111  SmallVector<StringRef, 1> enumValStrs;
112  enumValStrs.reserve(enumValues.size());
113  for (auto val : enumValues) {
114  enumValStrs.emplace_back(stringifyFn(val));
115  }
116  return builder.getStrArrayAttr(enumValStrs);
117 }
118 
119 /// Parses the next string attribute in `parser` as an enumerant of the given
120 /// `EnumClass`.
///
/// On success `value` receives the enumerant. No attribute is added to any
/// OperationState: `attrName` is only used as the key in a throw-away
/// NamedAttrList and in the diagnostics below.
121 template <typename EnumClass>
122 static ParseResult
// NOTE(review): the line naming the function and its `value` / `parser`
// parameters is missing from this listing.
124  StringRef attrName = spirv::attributeName<EnumClass>()) {
125  Attribute attrVal;
// Scratch list required by parseAttribute; its contents are discarded.
126  NamedAttrList attr;
127  auto loc = parser.getCurrentLocation();
128  if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(),
129  attrName, attr)) {
130  return failure();
131  }
132  if (!attrVal.isa<StringAttr>()) {
133  return parser.emitError(loc, "expected ")
134  << attrName << " attribute specified as string";
135  }
// Map the string onto an enumerant; an empty result means an unknown name.
136  auto attrOptional =
137  spirv::symbolizeEnum<EnumClass>(attrVal.cast<StringAttr>().getValue());
138  if (!attrOptional) {
139  return parser.emitError(loc, "invalid ")
140  << attrName << " attribute specification: " << attrVal;
141  }
142  value = attrOptional.getValue();
143  return success();
144 }
145 
146 /// Parses the next string attribute in `parser` as an enumerant of the given
147 /// `EnumClass` and inserts the enumerant into `state` as an 32-bit integer
148 /// attribute with the enum class's name as attribute name.
149 template <typename EnumClass>
150 static ParseResult
151 parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, OperationState &state,
152  StringRef attrName = spirv::attributeName<EnumClass>()) {
153  if (parseEnumStrAttr(value, parser)) {
154  return failure();
155  }
156  state.addAttribute(attrName, parser.getBuilder().getI32IntegerAttr(
157  llvm::bit_cast<int32_t>(value)));
158  return success();
159 }
160 
161 /// Parses the next keyword in `parser` as an enumerant of the given `EnumClass`
162 /// and inserts the enumerant into `state` as an 32-bit integer attribute with
163 /// the enum class's name as attribute name.
164 template <typename EnumClass>
165 static ParseResult
166 parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser,
167  OperationState &state,
168  StringRef attrName = spirv::attributeName<EnumClass>()) {
169  if (parseEnumKeywordAttr(value, parser)) {
170  return failure();
171  }
172  state.addAttribute(attrName, parser.getBuilder().getI32IntegerAttr(
173  llvm::bit_cast<int32_t>(value)));
174  return success();
175 }
176 
177 /// Parses Function, Selection and Loop control attributes. If no control is
178 /// specified, "None" is used as a default.
///
/// Accepted syntax: (`control` `(` enum-keyword `)`)?
179 template <typename EnumClass>
180 static ParseResult
// NOTE(review): the line naming the function and its `parser` / `state`
// parameters is missing from this listing.
182  StringRef attrName = spirv::attributeName<EnumClass>()) {
// NOTE(review): the explicit form below stores the attribute through
// parseEnumKeywordAttr(control, parser, state), which does not forward
// `attrName` and so uses its default spirv::attributeName<EnumClass>(). The
// implicit "None" form stores under `attrName`. The two only agree when
// callers keep the default `attrName`.
183  if (succeeded(parser.parseOptionalKeyword(kControl))) {
184  EnumClass control;
185  if (parser.parseLParen() || parseEnumKeywordAttr(control, parser, state) ||
186  parser.parseRParen())
187  return failure();
188  return success();
189  }
190  // Set control to "None" otherwise.
// NOTE(review): assumes the "None" enumerant is encoded as 0 — confirm.
191  Builder builder = parser.getBuilder();
192  state.addAttribute(attrName, builder.getI32IntegerAttr(0));
193  return success();
194 }
195 
196 /// Parses optional memory access attributes attached to a memory access
197 /// operand/pointer. Specifically, parses the following syntax:
198 /// (`[` memory-access `]`)?
199 /// where:
200 /// memory-access ::= `"None"` | `"Volatile"` | `"Aligned", `
201 /// integer-literal | `"NonTemporal"`
///
/// parseEnumStrAttr adds the memory access enumerant to `state`. When the
/// enumerant contains `Aligned`, the following integer is added to
/// `state.attributes` under kAlignmentAttrName as an i32 attribute.
// NOTE(review): the line naming the function and its `parser` parameter is
// missing from this listing.
203  OperationState &state) {
204  // Parse an optional list of attributes starting with '['
205  if (parser.parseOptionalLSquare()) {
206  // Nothing to do
207  return success();
208  }
209 
210  spirv::MemoryAccess memoryAccessAttr;
// NOTE(review): the continuation of this call (presumably passing
// kMemoryAccessAttrName) is missing from this listing — confirm.
211  if (parseEnumStrAttr(memoryAccessAttr, parser, state,
213  return failure();
214  }
215 
216  if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) {
217  // Parse integer attribute for alignment.
218  Attribute alignmentAttr;
219  Type i32Type = parser.getBuilder().getIntegerType(32);
220  if (parser.parseComma() ||
221  parser.parseAttribute(alignmentAttr, i32Type, kAlignmentAttrName,
222  state.attributes)) {
223  return failure();
224  }
225  }
226  return parser.parseRSquare();
227 }
228 
229 // TODO Make sure to merge this and the previous function into one template
230 // parameterized by memory access attribute name and alignment. Doing so now
231 // causes VS2017 to produce an internal error (at the call site) that's
232 // not detailed enough to understand what is happening.
//
// Same grammar as the memory access parser above, but an alignment is stored
// under kSourceAlignmentAttrName ("source_alignment").
// NOTE(review): the line naming the function and its `parser` parameter is
// missing from this listing.
234  OperationState &state) {
235  // Parse an optional list of attributes starting with '['
236  if (parser.parseOptionalLSquare()) {
237  // Nothing to do
238  return success();
239  }
240 
241  spirv::MemoryAccess memoryAccessAttr;
// NOTE(review): the continuation of this call (presumably passing
// kSourceMemoryAccessAttrName) is missing from this listing — confirm.
242  if (parseEnumStrAttr(memoryAccessAttr, parser, state,
244  return failure();
245  }
246 
247  if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) {
248  // Parse integer attribute for alignment.
249  Attribute alignmentAttr;
250  Type i32Type = parser.getBuilder().getIntegerType(32);
251  if (parser.parseComma() ||
252  parser.parseAttribute(alignmentAttr, i32Type, kSourceAlignmentAttrName,
253  state.attributes)) {
254  return failure();
255  }
256  }
257  return parser.parseRSquare();
258 }
259 
/// Prints the optional memory access list of `memoryOp` as ` ["<access>"]` or
/// ` ["<access>", <alignment>]`.
///
/// Non-empty `memoryAccessAtrrValue` / `alignmentAttrValue` take precedence
/// over the op's memory_access() / alignment() accessors. Each printed
/// attribute name is appended to `elidedAttrs`. The storage class attribute
/// name is always appended.
260 template <typename MemoryOpTy>
// NOTE(review): the line naming the function is missing from this listing.
262  MemoryOpTy memoryOp, OpAsmPrinter &printer,
263  SmallVectorImpl<StringRef> &elidedAttrs,
264  Optional<spirv::MemoryAccess> memoryAccessAtrrValue = None,
265  Optional<uint32_t> alignmentAttrValue = None) {
266  // Print optional memory access attribute.
267  if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
268  : memoryOp.memory_access())) {
269  elidedAttrs.push_back(kMemoryAccessAttrName);
270 
271  printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
272 
273  if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
274  // Print integer alignment attribute.
275  if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
276  : memoryOp.alignment())) {
277  elidedAttrs.push_back(kAlignmentAttrName);
// NOTE(review): `alignment` is the Optional itself, not the unwrapped
// integer; confirm the stream operator used here prints only the number.
278  printer << ", " << alignment;
279  }
280  }
281  printer << "]";
282  }
283  elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
284 }
285 
286 // TODO Make sure to merge this and the previous function into one template
287 // parameterized by memory access attribute name and alignment. Doing so now
288 // causes VS2017 to produce an internal error (at the call site) that's
289 // not detailed enough to understand what is happening.
//
// Source-operand variant of the memory access printer above. It always emits
// a leading ", " and elides kSourceMemoryAccessAttrName /
// kSourceAlignmentAttrName instead of the non-source names.
290 template <typename MemoryOpTy>
// NOTE(review): the line naming the function is missing from this listing.
292  MemoryOpTy memoryOp, OpAsmPrinter &printer,
293  SmallVectorImpl<StringRef> &elidedAttrs,
294  Optional<spirv::MemoryAccess> memoryAccessAtrrValue = None,
295  Optional<uint32_t> alignmentAttrValue = None) {
296 
// NOTE(review): combined with the " [" below, this produces two spaces
// before '[' when a memory access is printed.
297  printer << ", ";
298 
299  // Print optional memory access attribute.
// NOTE(review): without explicit values this falls back to memory_access() /
// alignment(), not to source-specific accessors; presumably callers always
// pass explicit values — confirm.
300  if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
301  : memoryOp.memory_access())) {
302  elidedAttrs.push_back(kSourceMemoryAccessAttrName);
303 
304  printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
305 
306  if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
307  // Print integer alignment attribute.
308  if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
309  : memoryOp.alignment())) {
310  elidedAttrs.push_back(kSourceAlignmentAttrName);
311  printer << ", " << alignment;
312  }
313  }
314  printer << "]";
315  }
316  elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
317 }
318 
/// Parses an optional `[` "<image-operands>" `]` list into `attr`.
/// Succeeds and leaves `attr` untouched when the next token is not `[`.
// NOTE(review): the line naming the function and its `parser` parameter is
// missing from this listing.
320  spirv::ImageOperandsAttr &attr) {
321  // Expect image operands
322  if (parser.parseOptionalLSquare())
323  return success();
324 
325  spirv::ImageOperands imageOperands;
326  if (parseEnumStrAttr(imageOperands, parser))
327  return failure();
328 
329  attr = spirv::ImageOperandsAttr::get(parser.getContext(), imageOperands);
330 
331  return parser.parseRSquare();
332 }
333 
334 static void printImageOperands(OpAsmPrinter &printer, Operation *imageOp,
335  spirv::ImageOperandsAttr attr) {
336  if (attr) {
337  auto strImageOperands = stringifyImageOperands(attr.getValue());
338  printer << "[\"" << strImageOperands << "\"]";
339  }
340 }
341 
/// Verifies the image operands of `imageOp` against `attr`:
///  - without `attr`, no trailing `operands` may be present;
///  - with `attr`, operand kinds listed in `noSupportOperands` are treated as
///    unimplemented.
342 template <typename Op>
// NOTE(review): the line naming the function and its `imageOp` parameter is
// missing from this listing.
344  spirv::ImageOperandsAttr attr,
345  Operation::operand_range operands) {
346  if (!attr) {
347  if (operands.empty())
348  return success();
349 
350  return imageOp.emitError("the Image Operands should encode what operands "
351  "follow, as per Image Operands");
352  }
353 
354  // TODO: Add the validation rules for the following Image Operands.
355  spirv::ImageOperands noSupportOperands =
356  spirv::ImageOperands::Bias | spirv::ImageOperands::Lod |
357  spirv::ImageOperands::Grad | spirv::ImageOperands::ConstOffset |
358  spirv::ImageOperands::Offset | spirv::ImageOperands::ConstOffsets |
359  spirv::ImageOperands::Sample | spirv::ImageOperands::MinLod |
360  spirv::ImageOperands::MakeTexelAvailable |
361  spirv::ImageOperands::MakeTexelVisible |
362  spirv::ImageOperands::SignExtend | spirv::ImageOperands::ZeroExtend;
363 
// NOTE(review): llvm_unreachable on a verification path aborts (or is UB in
// release builds) for IR that merely uses these operands; emitting a
// diagnostic would be safer. Confirm whether bitEnumContains checks "any" or
// "all" of the bits in `noSupportOperands`.
364  if (spirv::bitEnumContains(attr.getValue(), noSupportOperands))
365  llvm_unreachable("unimplemented operands of Image Operands");
366 
367  return success();
368 }
369 
/// Verifies the bit widths of a cast op's operand and result.
/// With `skipBitWidthCheck` nothing is checked. Otherwise the widths must be
/// equal when `requireSameBitWidth` is true and must differ when it is false.
/// Vector and cooperative-matrix types are compared by their element types.
// NOTE(review): the first signature line (function name and `op`) is missing
// from this listing.
371  bool requireSameBitWidth = true,
372  bool skipBitWidthCheck = false) {
373  // Some CastOps have no limit on bit widths for result and operand type.
374  if (skipBitWidthCheck)
375  return success();
376 
377  Type operandType = op->getOperand(0).getType();
378  Type resultType = op->getResult(0).getType();
379 
380  // ODS checks that result type and operand type have the same shape.
381  if (auto vectorType = operandType.dyn_cast<VectorType>()) {
382  operandType = vectorType.getElementType();
383  resultType = resultType.cast<VectorType>().getElementType();
384  }
385 
386  if (auto coopMatrixType =
387  operandType.dyn_cast<spirv::CooperativeMatrixNVType>()) {
388  operandType = coopMatrixType.getElementType();
// NOTE(review): the right-hand side of this assignment is missing from this
// listing; presumably it takes the result's cooperative-matrix element type.
389  resultType =
391  }
392 
393  auto operandTypeBitWidth = operandType.getIntOrFloatBitWidth();
394  auto resultTypeBitWidth = resultType.getIntOrFloatBitWidth();
395  auto isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth;
396 
397  if (requireSameBitWidth) {
398  if (!isSameBitWidth) {
399  return op->emitOpError(
400  "expected the same bit widths for operand type and result "
401  "type, but provided ")
402  << operandType << " and " << resultType;
403  }
404  return success();
405  }
406 
407  if (isSameBitWidth) {
408  return op->emitOpError(
409  "expected the different bit widths for operand type and result "
410  "type, but provided ")
411  << operandType << " and " << resultType;
412  }
413  return success();
414 }
415 
416 template <typename MemoryOpTy>
417 static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
418  // ODS checks for attributes values. Just need to verify that if the
419  // memory-access attribute is Aligned, then the alignment attribute must be
420  // present.
421  auto *op = memoryOp.getOperation();
422  auto memAccessAttr = op->getAttr(kMemoryAccessAttrName);
423  if (!memAccessAttr) {
424  // Alignment attribute shouldn't be present if memory access attribute is
425  // not present.
426  if (op->getAttr(kAlignmentAttrName)) {
427  return memoryOp.emitOpError(
428  "invalid alignment specification without aligned memory access "
429  "specification");
430  }
431  return success();
432  }
433 
434  auto memAccessVal = memAccessAttr.template cast<IntegerAttr>();
435  auto memAccess = spirv::symbolizeMemoryAccess(memAccessVal.getInt());
436 
437  if (!memAccess) {
438  return memoryOp.emitOpError("invalid memory access specifier: ")
439  << memAccessVal;
440  }
441 
442  if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
443  if (!op->getAttr(kAlignmentAttrName)) {
444  return memoryOp.emitOpError("missing alignment value");
445  }
446  } else {
447  if (op->getAttr(kAlignmentAttrName)) {
448  return memoryOp.emitOpError(
449  "invalid alignment specification with non-aligned memory access "
450  "specification");
451  }
452  }
453  return success();
454 }
455 
456 // TODO Make sure to merge this and the previous function into one template
457 // parameterized by memory access attribute name and alignment. Doing so now
458 // causes VS2017 to produce an internal error (at the call site) that's
459 // not detailed enough to understand what is happening.
//
// Applies the same alignment / memory access consistency checks as the
// verifier above to the source_memory_access / source_alignment attributes.
460 template <typename MemoryOpTy>
// NOTE(review): the line naming the function and its `memoryOp` parameter is
// missing from this listing.
462  // ODS checks for attributes values. Just need to verify that if the
463  // memory-access attribute is Aligned, then the alignment attribute must be
464  // present.
465  auto *op = memoryOp.getOperation();
466  auto memAccessAttr = op->getAttr(kSourceMemoryAccessAttrName);
467  if (!memAccessAttr) {
468  // Alignment attribute shouldn't be present if memory access attribute is
469  // not present.
470  if (op->getAttr(kSourceAlignmentAttrName)) {
471  return memoryOp.emitOpError(
472  "invalid alignment specification without aligned memory access "
473  "specification");
474  }
475  return success();
476  }
477 
478  auto memAccessVal = memAccessAttr.template cast<IntegerAttr>();
479  auto memAccess = spirv::symbolizeMemoryAccess(memAccessVal.getInt());
480 
481  if (!memAccess) {
482  return memoryOp.emitOpError("invalid memory access specifier: ")
483  << memAccessVal;
484  }
485 
486  if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
487  if (!op->getAttr(kSourceAlignmentAttrName)) {
488  return memoryOp.emitOpError("missing alignment value");
489  }
490  } else {
491  if (op->getAttr(kSourceAlignmentAttrName)) {
492  return memoryOp.emitOpError(
493  "invalid alignment specification with non-aligned memory access "
494  "specification");
495  }
496  }
497  return success();
498 }
499 
500 static LogicalResult
501 verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics) {
502  // According to the SPIR-V specification:
503  // "Despite being a mask and allowing multiple bits to be combined, it is
504  // invalid for more than one of these four bits to be set: Acquire, Release,
505  // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and
506  // Release semantics is done by setting the AcquireRelease bit, not by setting
507  // two bits."
508  auto atMostOneInSet = spirv::MemorySemantics::Acquire |
509  spirv::MemorySemantics::Release |
510  spirv::MemorySemantics::AcquireRelease |
511  spirv::MemorySemantics::SequentiallyConsistent;
512 
513  auto bitCount = llvm::countPopulation(
514  static_cast<uint32_t>(memorySemantics & atMostOneInSet));
515  if (bitCount > 1) {
516  return op->emitError(
517  "expected at most one of these four memory constraints "
518  "to be set: `Acquire`, `Release`,"
519  "`AcquireRelease` or `SequentiallyConsistent`");
520  }
521  return success();
522 }
523 
524 template <typename LoadStoreOpTy>
525 static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr,
526  Value val) {
527  // ODS already checks ptr is spirv::PointerType. Just check that the pointee
528  // type of the pointer and the type of the value are the same
529  //
530  // TODO: Check that the value type satisfies restrictions of
531  // SPIR-V OpLoad/OpStore operations
532  if (val.getType() !=
533  ptr.getType().cast<spirv::PointerType>().getPointeeType()) {
534  return op.emitOpError("mismatch in result type and pointer type");
535  }
536  return success();
537 }
538 
/// Verifies that `val`'s type, or its element type when `val` is a vector,
/// equals the pointee type of `ptr`. Emits an op error on `op` otherwise.
539 template <typename BlockReadWriteOpTy>
// NOTE(review): the line naming the function and its `op` parameter is
// missing from this listing. `ptr` is cast to spirv::PointerType without a
// check; presumably ODS guarantees that — confirm.
541  Value ptr, Value val) {
542  auto valType = val.getType();
543  if (auto valVecTy = valType.dyn_cast<VectorType>())
544  valType = valVecTy.getElementType();
545 
546  if (valType != ptr.getType().cast<spirv::PointerType>().getPointeeType()) {
547  return op.emitOpError("mismatch in result type and pointer type");
548  }
549  return success();
550 }
551 
/// Parses optional variable decorations followed by an optional attribute
/// dictionary:
///   (`bind` `(` set `,` binding `)` | <built-in-keyword> `(` string `)`)?
///   attr-dict?
/// The descriptor set and binding are stored as i32 attributes and the
/// built-in as a string attribute. The attribute names are the snake_case
/// forms of the corresponding decoration names.
// NOTE(review): the line naming the function and its `parser` parameter is
// missing from this listing.
// NOTE(review): `bind(...)` and the built-in form are alternatives here
// (else-if), while printVariableDecorations can emit both; confirm such IR
// never needs to round-trip through this parser.
553  OperationState &state) {
554  auto builtInName = llvm::convertToSnakeFromCamelCase(
555  stringifyDecoration(spirv::Decoration::BuiltIn));
556  if (succeeded(parser.parseOptionalKeyword("bind"))) {
557  Attribute set, binding;
558  // Parse optional descriptor binding
559  auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
560  stringifyDecoration(spirv::Decoration::DescriptorSet));
561  auto bindingName = llvm::convertToSnakeFromCamelCase(
562  stringifyDecoration(spirv::Decoration::Binding));
563  Type i32Type = parser.getBuilder().getIntegerType(32);
564  if (parser.parseLParen() ||
565  parser.parseAttribute(set, i32Type, descriptorSetName,
566  state.attributes) ||
567  parser.parseComma() ||
568  parser.parseAttribute(binding, i32Type, bindingName,
569  state.attributes) ||
570  parser.parseRParen()) {
571  return failure();
572  }
573  } else if (succeeded(parser.parseOptionalKeyword(builtInName))) {
574  StringAttr builtIn;
575  if (parser.parseLParen() ||
576  parser.parseAttribute(builtIn, builtInName, state.attributes) ||
577  parser.parseRParen()) {
578  return failure();
579  }
580  }
581 
582  // Parse other attributes
583  if (parser.parseOptionalAttrDict(state.attributes))
584  return failure();
585 
586  return success();
587 }
588 
/// Prints ` bind(<set>, <binding>)` when both the descriptor set and binding
/// attributes are present, then ` <built-in-name>("<value>")` when the
/// built-in attribute is present. Finally it prints the remaining attributes,
/// skipping those in `elidedAttrs`.
// NOTE(review): the line naming the function and its `op` / `printer`
// parameters is missing from this listing.
// NOTE(review): convertToSnakeFromCamelCase returns std::string by value, so
// the StringRefs pushed into `elidedAttrs` below point into these locals and
// dangle once this function returns. Confirm callers never read
// `elidedAttrs` after the call.
590  SmallVectorImpl<StringRef> &elidedAttrs) {
591  // Print optional descriptor binding
592  auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
593  stringifyDecoration(spirv::Decoration::DescriptorSet));
594  auto bindingName = llvm::convertToSnakeFromCamelCase(
595  stringifyDecoration(spirv::Decoration::Binding));
596  auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
597  auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
598  if (descriptorSet && binding) {
599  elidedAttrs.push_back(descriptorSetName);
600  elidedAttrs.push_back(bindingName);
601  printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
602  << ")";
603  }
604 
605  // Print BuiltIn attribute if present
606  auto builtInName = llvm::convertToSnakeFromCamelCase(
607  stringifyDecoration(spirv::Decoration::BuiltIn));
608  if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
609  printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
610  elidedAttrs.push_back(builtInName);
611  }
612 
613  printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
614 }
615 
616 // Get bit width of types.
617 static unsigned getBitWidth(Type type) {
618  if (type.isa<spirv::PointerType>()) {
619  // Just return 64 bits for pointer types for now.
620  // TODO: Make sure not caller relies on the actual pointer width value.
621  return 64;
622  }
623 
624  if (type.isIntOrFloat())
625  return type.getIntOrFloatBitWidth();
626 
627  if (auto vectorType = type.dyn_cast<VectorType>()) {
628  assert(vectorType.getElementType().isIntOrFloat());
629  return vectorType.getNumElements() *
630  vectorType.getElementType().getIntOrFloatBitWidth();
631  }
632  llvm_unreachable("unhandled bit width computation for type");
633 }
634 
635 /// Walks the given type hierarchy with the given indices, potentially down
636 /// to component granularity, to select an element type. Returns null type and
637 /// emits an error through `emitErrorFn` on failure.
638 static Type
// NOTE(review): the line naming the function and its `type` / `indices`
// parameters is missing from this listing. `indices` is presumably an
// ArrayRef<int32_t>, since the Attribute overload passes a SmallVector<int32_t>
// — confirm.
640  function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
641  if (indices.empty()) {
642  emitErrorFn("expected at least one index for spv.CompositeExtract");
643  return nullptr;
644  }
645 
646  for (auto index : indices) {
647  if (auto cType = type.dyn_cast<spirv::CompositeType>()) {
// Bounds, including negativity, are only checked when the element count is
// known at compile time.
648  if (cType.hasCompileTimeKnownNumElements() &&
649  (index < 0 ||
650  static_cast<uint64_t>(index) >= cType.getNumElements())) {
651  emitErrorFn("index ") << index << " out of bounds for " << type;
652  return nullptr;
653  }
654  type = cType.getElementType(index);
655  } else {
656  emitErrorFn("cannot extract from non-composite type ")
657  << type << " with index " << index;
658  return nullptr;
659  }
660  }
661  return type;
662 }
663 
/// Converts `indices` (expected to be an ArrayAttr of IntegerAttr) into a list
/// of integer indices and forwards to the index-list overload of
/// getElementType. Returns a null Type, after emitting through `emitErrorFn`,
/// on malformed input.
664 static Type
// NOTE(review): the line naming the function and its `type` / `indices`
// parameters is missing from this listing.
666  function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
667  auto indicesArrayAttr = indices.dyn_cast<ArrayAttr>();
668  if (!indicesArrayAttr) {
669  emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
670  return nullptr;
671  }
672  if (indicesArrayAttr.empty()) {
673  emitErrorFn("expected at least one index for spv.CompositeExtract");
674  return nullptr;
675  }
676 
677  SmallVector<int32_t, 2> indexVals;
678  for (auto indexAttr : indicesArrayAttr) {
679  auto indexIntAttr = indexAttr.dyn_cast<IntegerAttr>();
680  if (!indexIntAttr) {
681  emitErrorFn("expected an 32-bit integer for index, but found '")
682  << indexAttr << "'";
683  return nullptr;
684  }
// NOTE(review): getInt() yields a 64-bit value that is narrowed to int32_t
// here. The attribute's bit width is not checked, despite the "32-bit"
// wording of the diagnostics.
685  indexVals.push_back(indexIntAttr.getInt());
686  }
687  return getElementType(type, indexVals, emitErrorFn);
688 }
689 
690 static Type getElementType(Type type, Attribute indices, Location loc) {
691  auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
692  return ::mlir::emitError(loc, err);
693  };
694  return getElementType(type, indices, errorFn);
695 }
696 
697 static Type getElementType(Type type, Attribute indices, OpAsmParser &parser,
698  llvm::SMLoc loc) {
699  auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
700  return parser.emitError(loc, err);
701  };
702  return getElementType(type, indices, errorFn);
703 }
704 
705 /// Returns true if the given `block` only contains one `spv.mlir.merge` op.
706 static inline bool isMergeBlock(Block &block) {
707  return !block.empty() && std::next(block.begin()) == block.end() &&
708  isa<spirv::MergeOp>(block.front());
709 }
710 
711 //===----------------------------------------------------------------------===//
712 // Common parsers and printers
713 //===----------------------------------------------------------------------===//
714 
715 // Parses an atomic update op. If the update op does not take a value (like
716 // AtomicIIncrement) `hasValue` must be false.
//
// Syntax: "<scope>" "<semantics>" ptr (`,` value)? `:` pointer-type
// The pointer operand gets the parsed pointer type. The value operand (if
// any) and the single result both get its pointee type.
// NOTE(review): the line naming the function and its `parser` parameter is
// missing from this listing.
718  OperationState &state, bool hasValue) {
719  spirv::Scope scope;
// NOTE(review): despite its name, `memoryScope` holds the memory semantics
// (it is stored under kSemanticsAttrName below).
720  spirv::MemorySemantics memoryScope;
// NOTE(review): the declaration of `operandInfo` is missing from this
// listing; `ptrInfo` and `valueInfo` are never used below.
722  OpAsmParser::OperandType ptrInfo, valueInfo;
723  Type type;
724  llvm::SMLoc loc;
725  if (parseEnumStrAttr(scope, parser, state, kMemoryScopeAttrName) ||
726  parseEnumStrAttr(memoryScope, parser, state, kSemanticsAttrName) ||
727  parser.parseOperandList(operandInfo, (hasValue ? 2 : 1)) ||
728  parser.getCurrentLocation(&loc) || parser.parseColonType(type))
729  return failure();
730 
731  auto ptrType = type.dyn_cast<spirv::PointerType>();
732  if (!ptrType)
733  return parser.emitError(loc, "expected pointer type");
734 
735  SmallVector<Type, 2> operandTypes;
736  operandTypes.push_back(ptrType);
737  if (hasValue)
738  operandTypes.push_back(ptrType.getPointeeType());
739  if (parser.resolveOperands(operandInfo, operandTypes, parser.getNameLoc(),
740  state.operands))
741  return failure();
742  return parser.addTypeToList(ptrType.getPointeeType(), state.types);
743 }
744 
745 // Prints an atomic update op.
746 static void printAtomicUpdateOp(Operation *op, OpAsmPrinter &printer) {
747  printer << " \"";
748  auto scopeAttr = op->getAttrOfType<IntegerAttr>(kMemoryScopeAttrName);
749  printer << spirv::stringifyScope(
750  static_cast<spirv::Scope>(scopeAttr.getInt()))
751  << "\" \"";
752  auto memorySemanticsAttr = op->getAttrOfType<IntegerAttr>(kSemanticsAttrName);
753  printer << spirv::stringifyMemorySemantics(
754  static_cast<spirv::MemorySemantics>(
755  memorySemanticsAttr.getInt()))
756  << "\" " << op->getOperands() << " : " << op->getOperand(0).getType();
757 }
758 
/// Returns the noun used in diagnostics to describe the type class `T`
/// ("integer", "float").
759 template <typename T>
760 static StringRef stringifyTypeName();
761 
// NOTE(review): the declarator lines of both specializations below are
// missing from this listing. Presumably they specialize for IntegerType and
// FloatType respectively — confirm.
762 template <>
764  return "integer";
765 }
766 
767 template <>
769  return "float";
770 }
771 
772 // Verifies an atomic update op.
//
// Checks that:
//  - the pointer operand's pointee is an `ExpectedElementType`;
//  - an optional second (value) operand has exactly the pointee type;
//  - the kSemanticsAttrName attribute passes verifyMemorySemantics.
773 template <typename ExpectedElementType>
// NOTE(review): the line naming the function and its `op` parameter is
// missing from this listing.
775  auto ptrType = op->getOperand(0).getType().cast<spirv::PointerType>();
776  auto elementType = ptrType.getPointeeType();
777  if (!elementType.isa<ExpectedElementType>())
778  return op->emitOpError() << "pointer operand must point to an "
779  << stringifyTypeName<ExpectedElementType>()
780  << " value, found " << elementType;
781 
782  if (op->getNumOperands() > 1) {
783  auto valueType = op->getOperand(1).getType();
784  if (valueType != elementType)
785  return op->emitOpError("expected value to have the same type as the "
786  "pointer operand's pointee type ")
787  << elementType << ", but found " << valueType;
788  }
789  auto memorySemantics = static_cast<spirv::MemorySemantics>(
790  op->getAttrOfType<IntegerAttr>(kSemanticsAttrName).getInt());
791  if (failed(verifyMemorySemantics(op, memorySemantics))) {
792  return failure();
793  }
794  return success();
795 }
796 
/// Parses a group non-uniform arithmetic op:
///   "<execution-scope>" "<group-operation>" value
///   (`(` cluster-size `)`)? `:` result-type
/// The value operand and the single result get `result-type`. The optional
/// cluster size operand gets i32.
// NOTE(review): this listing is missing the line naming the function and its
// `parser` parameter. It is also missing the attribute-name arguments of both
// parseEnumStrAttr calls and the condition that opens the cluster-size block
// (presumably an optional `cluster_size` keyword, matching the printer) —
// confirm.
798  OperationState &state) {
799  spirv::Scope executionScope;
800  spirv::GroupOperation groupOperation;
801  OpAsmParser::OperandType valueInfo;
802  if (parseEnumStrAttr(executionScope, parser, state,
804  parseEnumStrAttr(groupOperation, parser, state,
806  parser.parseOperand(valueInfo))
807  return failure();
808 
809  Optional<OpAsmParser::OperandType> clusterSizeInfo;
811  clusterSizeInfo = OpAsmParser::OperandType();
812  if (parser.parseLParen() || parser.parseOperand(*clusterSizeInfo) ||
813  parser.parseRParen())
814  return failure();
815  }
816 
817  Type resultType;
818  if (parser.parseColonType(resultType))
819  return failure();
820 
821  if (parser.resolveOperand(valueInfo, resultType, state.operands))
822  return failure();
823 
824  if (clusterSizeInfo.hasValue()) {
825  Type i32Type = parser.getBuilder().getIntegerType(32);
826  if (parser.resolveOperand(*clusterSizeInfo, i32Type, state.operands))
827  return failure();
828  }
829 
830  return parser.addTypeToList(resultType, state.types);
831 }
832 
/// Prints a group non-uniform arithmetic op as
///   "<scope>" "<group-operation>" value (cluster_size(size))? : result-type
/// The cluster size part is printed only when a second operand exists.
// NOTE(review): the line naming the function and its `groupOp` parameter is
// missing from this listing.
834  OpAsmPrinter &printer) {
835  printer << " \""
836  << stringifyScope(static_cast<spirv::Scope>(
837  groupOp->getAttrOfType<IntegerAttr>(kExecutionScopeAttrName)
838  .getInt()))
839  << "\" \""
840  << stringifyGroupOperation(static_cast<spirv::GroupOperation>(
841  groupOp->getAttrOfType<IntegerAttr>(kGroupOperationAttrName)
842  .getInt()))
843  << "\" " << groupOp->getOperand(0);
844 
845  if (groupOp->getNumOperands() > 1)
846  printer << " " << kClusterSize << '(' << groupOp->getOperand(1) << ')';
847  printer << " : " << groupOp->getResult(0).getType();
848 }
849 
/// Verifies a group non-uniform arithmetic op:
///  - the execution scope must be Workgroup or Subgroup;
///  - ClusteredReduce requires a cluster size operand;
///  - a cluster size operand, when present, must be defined by a constant op
///    and be a power of two.
// NOTE(review): the line naming the function and its `groupOp` parameter is
// missing from this listing.
851  spirv::Scope scope = static_cast<spirv::Scope>(
852  groupOp->getAttrOfType<IntegerAttr>(kExecutionScopeAttrName).getInt());
853  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
854  return groupOp->emitOpError(
855  "execution scope must be 'Workgroup' or 'Subgroup'");
856 
857  spirv::GroupOperation operation = static_cast<spirv::GroupOperation>(
858  groupOp->getAttrOfType<IntegerAttr>(kGroupOperationAttrName).getInt());
859  if (operation == spirv::GroupOperation::ClusteredReduce &&
860  groupOp->getNumOperands() == 1)
861  return groupOp->emitOpError("cluster size operand must be provided for "
862  "'ClusteredReduce' group operation");
863  if (groupOp->getNumOperands() > 1) {
864  Operation *sizeOp = groupOp->getOperand(1).getDefiningOp();
865  int32_t clusterSize = 0;
866 
867  // TODO: support specialization constant here.
868  if (failed(extractValueFromConstOp(sizeOp, clusterSize)))
869  return groupOp->emitOpError(
870  "cluster size operand must come from a constant op");
871 
// NOTE(review): isPowerOf2_32 takes an unsigned argument, so a negative
// clusterSize is converted first (INT32_MIN becomes 0x80000000, a power of
// two, and is accepted) — confirm negative sizes are rejected elsewhere.
872  if (!llvm::isPowerOf2_32(clusterSize))
873  return groupOp->emitOpError(
874  "cluster size operand must be a power of two");
875  }
876  return success();
877 }
878 
/// Parses `operand : type`. The operand and the single result both get
/// `type`.
// NOTE(review): the line naming the function and its `parser` / `state`
// parameters is missing from this listing.
880  OpAsmParser::OperandType operandInfo;
881  Type type;
882  if (parser.parseOperand(operandInfo) || parser.parseColonType(type) ||
883  parser.resolveOperands(operandInfo, type, state.operands)) {
884  return failure();
885  }
886  state.addTypes(type);
887  return success();
888 }
889 
890 static void printUnaryOp(Operation *unaryOp, OpAsmPrinter &printer) {
891  printer << ' ' << unaryOp->getOperand(0) << " : "
892  << unaryOp->getOperand(0).getType();
893 }
894 
895 /// Result of a logical op must be a scalar or vector of boolean type.
896 static Type getUnaryOpResultType(Builder &builder, Type operandType) {
897  Type resultType = builder.getIntegerType(1);
898  if (auto vecType = operandType.dyn_cast<VectorType>()) {
899  return VectorType::get(vecType.getNumElements(), resultType);
900  }
901  return resultType;
902 }
903 
905  OperationState &state) {
906  OpAsmParser::OperandType operandInfo;
907  Type type;
908  if (parser.parseOperand(operandInfo) || parser.parseColonType(type) ||
909  parser.resolveOperand(operandInfo, type, state.operands)) {
910  return failure();
911  }
912  state.addTypes(getUnaryOpResultType(parser.getBuilder(), type));
913  return success();
914 }
915 
917  OperationState &result) {
919  Type type;
920  if (parser.parseOperandList(ops, 2) || parser.parseColonType(type) ||
921  parser.resolveOperands(ops, type, result.operands)) {
922  return failure();
923  }
924  result.addTypes(getUnaryOpResultType(parser.getBuilder(), type));
925  return success();
926 }
927 
928 static void printLogicalOp(Operation *logicalOp, OpAsmPrinter &printer) {
929  printer << ' ' << logicalOp->getOperands() << " : "
930  << logicalOp->getOperand(0).getType();
931 }
932 
935  Type baseType;
936  Type shiftType;
937  auto loc = parser.getCurrentLocation();
938 
939  if (parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
940  parser.parseType(baseType) || parser.parseComma() ||
941  parser.parseType(shiftType) ||
942  parser.resolveOperands(operandInfo, {baseType, shiftType}, loc,
943  state.operands)) {
944  return failure();
945  }
946  state.addTypes(baseType);
947  return success();
948 }
949 
950 static void printShiftOp(Operation *op, OpAsmPrinter &printer) {
951  Value base = op->getOperand(0);
952  Value shift = op->getOperand(1);
953  printer << ' ' << base << ", " << shift << " : " << base.getType() << ", "
954  << shift.getType();
955 }
956 
958  if (op->getOperand(0).getType() != op->getResult(0).getType()) {
959  return op->emitError("expected the same type for the first operand and "
960  "result, but provided ")
961  << op->getOperand(0).getType() << " and "
962  << op->getResult(0).getType();
963  }
964  return success();
965 }
966 
967 static void buildLogicalBinaryOp(OpBuilder &builder, OperationState &state,
968  Value lhs, Value rhs) {
969  assert(lhs.getType() == rhs.getType());
970 
971  Type boolType = builder.getI1Type();
972  if (auto vecType = lhs.getType().dyn_cast<VectorType>())
973  boolType = VectorType::get(vecType.getShape(), boolType);
974  state.addTypes(boolType);
975 
976  state.addOperands({lhs, rhs});
977 }
978 
979 static void buildLogicalUnaryOp(OpBuilder &builder, OperationState &state,
980  Value value) {
981  Type boolType = builder.getI1Type();
982  if (auto vecType = value.getType().dyn_cast<VectorType>())
983  boolType = VectorType::get(vecType.getShape(), boolType);
984  state.addTypes(boolType);
985 
986  state.addOperands(value);
987 }
988 
989 //===----------------------------------------------------------------------===//
990 // spv.AccessChainOp
991 //===----------------------------------------------------------------------===//
992 
993 static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
994  auto ptrType = type.dyn_cast<spirv::PointerType>();
995  if (!ptrType) {
996  emitError(baseLoc, "'spv.AccessChain' op expected a pointer "
997  "to composite type, but provided ")
998  << type;
999  return nullptr;
1000  }
1001 
1002  auto resultType = ptrType.getPointeeType();
1003  auto resultStorageClass = ptrType.getStorageClass();
1004  int32_t index = 0;
1005 
1006  for (auto indexSSA : indices) {
1007  auto cType = resultType.dyn_cast<spirv::CompositeType>();
1008  if (!cType) {
1009  emitError(baseLoc,
1010  "'spv.AccessChain' op cannot extract from non-composite type ")
1011  << resultType << " with index " << index;
1012  return nullptr;
1013  }
1014  index = 0;
1015  if (resultType.isa<spirv::StructType>()) {
1016  Operation *op = indexSSA.getDefiningOp();
1017  if (!op) {
1018  emitError(baseLoc, "'spv.AccessChain' op index must be an "
1019  "integer spv.Constant to access "
1020  "element of spv.struct");
1021  return nullptr;
1022  }
1023 
1024  // TODO: this should be relaxed to allow
1025  // integer literals of other bitwidths.
1026  if (failed(extractValueFromConstOp(op, index))) {
1027  emitError(baseLoc,
1028  "'spv.AccessChain' index must be an integer spv.Constant to "
1029  "access element of spv.struct, but provided ")
1030  << op->getName();
1031  return nullptr;
1032  }
1033  if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) {
1034  emitError(baseLoc, "'spv.AccessChain' op index ")
1035  << index << " out of bounds for " << resultType;
1036  return nullptr;
1037  }
1038  }
1039  resultType = cType.getElementType(index);
1040  }
1041  return spirv::PointerType::get(resultType, resultStorageClass);
1042 }
1043 
1044 void spirv::AccessChainOp::build(OpBuilder &builder, OperationState &state,
1045  Value basePtr, ValueRange indices) {
1046  auto type = getElementPtrType(basePtr.getType(), indices, state.location);
1047  assert(type && "Unable to deduce return type based on basePtr and indices");
1048  build(builder, state, type, basePtr, indices);
1049 }
1050 
1052  OperationState &state) {
1053  OpAsmParser::OperandType ptrInfo;
1055  Type type;
1056  auto loc = parser.getCurrentLocation();
1057  SmallVector<Type, 4> indicesTypes;
1058 
1059  if (parser.parseOperand(ptrInfo) ||
1060  parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
1061  parser.parseColonType(type) ||
1062  parser.resolveOperand(ptrInfo, type, state.operands)) {
1063  return failure();
1064  }
1065 
1066  // Check that the provided indices list is not empty before parsing their
1067  // type list.
1068  if (indicesInfo.empty()) {
1069  return emitError(state.location, "'spv.AccessChain' op expected at "
1070  "least one index ");
1071  }
1072 
1073  if (parser.parseComma() || parser.parseTypeList(indicesTypes))
1074  return failure();
1075 
1076  // Check that the indices types list is not empty and that it has a one-to-one
1077  // mapping to the provided indices.
1078  if (indicesTypes.size() != indicesInfo.size()) {
1079  return emitError(state.location, "'spv.AccessChain' op indices "
1080  "types' count must be equal to indices "
1081  "info count");
1082  }
1083 
1084  if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands))
1085  return failure();
1086 
1087  auto resultType = getElementPtrType(
1088  type, llvm::makeArrayRef(state.operands).drop_front(), state.location);
1089  if (!resultType) {
1090  return failure();
1091  }
1092 
1093  state.addTypes(resultType);
1094  return success();
1095 }
1096 
1097 template <typename Op>
1098 static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer) {
1099  printer << ' ' << op.base_ptr() << '[' << indices
1100  << "] : " << op.base_ptr().getType() << ", " << indices.getTypes();
1101 }
1102 
1103 static void print(spirv::AccessChainOp op, OpAsmPrinter &printer) {
1104  printAccessChain(op, op.indices(), printer);
1105 }
1106 
1107 template <typename Op>
1108 static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) {
1109  auto resultType = getElementPtrType(accessChainOp.base_ptr().getType(),
1110  indices, accessChainOp.getLoc());
1111  if (!resultType)
1112  return failure();
1113 
1114  auto providedResultType =
1115  accessChainOp.getType().template dyn_cast<spirv::PointerType>();
1116  if (!providedResultType)
1117  return accessChainOp.emitOpError(
1118  "result type must be a pointer, but provided")
1119  << providedResultType;
1120 
1121  if (resultType != providedResultType)
1122  return accessChainOp.emitOpError("invalid result type: expected ")
1123  << resultType << ", but provided " << providedResultType;
1124 
1125  return success();
1126 }
1127 
1128 static LogicalResult verify(spirv::AccessChainOp accessChainOp) {
1129  return verifyAccessChain(accessChainOp, accessChainOp.indices());
1130 }
1131 
1132 //===----------------------------------------------------------------------===//
1133 // spv.mlir.addressof
1134 //===----------------------------------------------------------------------===//
1135 
1136 void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state,
1137  spirv::GlobalVariableOp var) {
1138  build(builder, state, var.type(), SymbolRefAttr::get(var));
1139 }
1140 
1141 static LogicalResult verify(spirv::AddressOfOp addressOfOp) {
1142  auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
1143  SymbolTable::lookupNearestSymbolFrom(addressOfOp->getParentOp(),
1144  addressOfOp.variableAttr()));
1145  if (!varOp) {
1146  return addressOfOp.emitOpError("expected spv.GlobalVariable symbol");
1147  }
1148  if (addressOfOp.pointer().getType() != varOp.type()) {
1149  return addressOfOp.emitOpError(
1150  "result type mismatch with the referenced global variable's type");
1151  }
1152  return success();
1153 }
1154 
1155 template <typename T>
1156 static void printAtomicCompareExchangeImpl(T atomOp, OpAsmPrinter &printer) {
1157  printer << " \"" << stringifyScope(atomOp.memory_scope()) << "\" \""
1158  << stringifyMemorySemantics(atomOp.equal_semantics()) << "\" \""
1159  << stringifyMemorySemantics(atomOp.unequal_semantics()) << "\" "
1160  << atomOp.getOperands() << " : " << atomOp.pointer().getType();
1161 }
1162 
1164  OperationState &state) {
1165  spirv::Scope memoryScope;
1166  spirv::MemorySemantics equalSemantics, unequalSemantics;
1168  Type type;
1169  if (parseEnumStrAttr(memoryScope, parser, state, kMemoryScopeAttrName) ||
1170  parseEnumStrAttr(equalSemantics, parser, state,
1172  parseEnumStrAttr(unequalSemantics, parser, state,
1174  parser.parseOperandList(operandInfo, 3))
1175  return failure();
1176 
1177  auto loc = parser.getCurrentLocation();
1178  if (parser.parseColonType(type))
1179  return failure();
1180 
1181  auto ptrType = type.dyn_cast<spirv::PointerType>();
1182  if (!ptrType)
1183  return parser.emitError(loc, "expected pointer type");
1184 
1185  if (parser.resolveOperands(
1186  operandInfo,
1187  {ptrType, ptrType.getPointeeType(), ptrType.getPointeeType()},
1188  parser.getNameLoc(), state.operands))
1189  return failure();
1190 
1191  return parser.addTypeToList(ptrType.getPointeeType(), state.types);
1192 }
1193 
1194 template <typename T>
1196  // According to the spec:
1197  // "The type of Value must be the same as Result Type. The type of the value
1198  // pointed to by Pointer must be the same as Result Type. This type must also
1199  // match the type of Comparator."
1200  if (atomOp.getType() != atomOp.value().getType())
1201  return atomOp.emitOpError("value operand must have the same type as the op "
1202  "result, but found ")
1203  << atomOp.value().getType() << " vs " << atomOp.getType();
1204 
1205  if (atomOp.getType() != atomOp.comparator().getType())
1206  return atomOp.emitOpError(
1207  "comparator operand must have the same type as the op "
1208  "result, but found ")
1209  << atomOp.comparator().getType() << " vs " << atomOp.getType();
1210 
1211  Type pointeeType = atomOp.pointer()
1212  .getType()
1213  .template cast<spirv::PointerType>()
1214  .getPointeeType();
1215  if (atomOp.getType() != pointeeType)
1216  return atomOp.emitOpError(
1217  "pointer operand's pointee type must have the same "
1218  "as the op result type, but found ")
1219  << pointeeType << " vs " << atomOp.getType();
1220 
1221  // TODO: Unequal cannot be set to Release or Acquire and Release.
1222  // In addition, Unequal cannot be set to a stronger memory-order then Equal.
1223 
1224  return success();
1225 }
1226 
1227 //===----------------------------------------------------------------------===//
1228 // spv.AtomicExchange
1229 //===----------------------------------------------------------------------===//
1230 
1231 static void print(spirv::AtomicExchangeOp atomOp, OpAsmPrinter &printer) {
1232  printer << " \"" << stringifyScope(atomOp.memory_scope()) << "\" \""
1233  << stringifyMemorySemantics(atomOp.semantics()) << "\" "
1234  << atomOp.getOperands() << " : " << atomOp.pointer().getType();
1235 }
1236 
1238  OperationState &state) {
1239  spirv::Scope memoryScope;
1240  spirv::MemorySemantics semantics;
1242  Type type;
1243  if (parseEnumStrAttr(memoryScope, parser, state, kMemoryScopeAttrName) ||
1244  parseEnumStrAttr(semantics, parser, state, kSemanticsAttrName) ||
1245  parser.parseOperandList(operandInfo, 2))
1246  return failure();
1247 
1248  auto loc = parser.getCurrentLocation();
1249  if (parser.parseColonType(type))
1250  return failure();
1251 
1252  auto ptrType = type.dyn_cast<spirv::PointerType>();
1253  if (!ptrType)
1254  return parser.emitError(loc, "expected pointer type");
1255 
1256  if (parser.resolveOperands(operandInfo, {ptrType, ptrType.getPointeeType()},
1257  parser.getNameLoc(), state.operands))
1258  return failure();
1259 
1260  return parser.addTypeToList(ptrType.getPointeeType(), state.types);
1261 }
1262 
1263 static LogicalResult verify(spirv::AtomicExchangeOp atomOp) {
1264  if (atomOp.getType() != atomOp.value().getType())
1265  return atomOp.emitOpError("value operand must have the same type as the op "
1266  "result, but found ")
1267  << atomOp.value().getType() << " vs " << atomOp.getType();
1268 
1269  Type pointeeType =
1270  atomOp.pointer().getType().cast<spirv::PointerType>().getPointeeType();
1271  if (atomOp.getType() != pointeeType)
1272  return atomOp.emitOpError(
1273  "pointer operand's pointee type must have the same "
1274  "as the op result type, but found ")
1275  << pointeeType << " vs " << atomOp.getType();
1276 
1277  return success();
1278 }
1279 
1280 //===----------------------------------------------------------------------===//
1281 // spv.BitcastOp
1282 //===----------------------------------------------------------------------===//
1283 
1284 static LogicalResult verify(spirv::BitcastOp bitcastOp) {
1285  // TODO: The SPIR-V spec validation rules are different for different
1286  // versions.
1287  auto operandType = bitcastOp.operand().getType();
1288  auto resultType = bitcastOp.result().getType();
1289  if (operandType == resultType) {
1290  return bitcastOp.emitError(
1291  "result type must be different from operand type");
1292  }
1293  if (operandType.isa<spirv::PointerType>() &&
1294  !resultType.isa<spirv::PointerType>()) {
1295  return bitcastOp.emitError(
1296  "unhandled bit cast conversion from pointer type to non-pointer type");
1297  }
1298  if (!operandType.isa<spirv::PointerType>() &&
1299  resultType.isa<spirv::PointerType>()) {
1300  return bitcastOp.emitError(
1301  "unhandled bit cast conversion from non-pointer type to pointer type");
1302  }
1303  auto operandBitWidth = getBitWidth(operandType);
1304  auto resultBitWidth = getBitWidth(resultType);
1305  if (operandBitWidth != resultBitWidth) {
1306  return bitcastOp.emitOpError("mismatch in result type bitwidth ")
1307  << resultBitWidth << " and operand type bitwidth "
1308  << operandBitWidth;
1309  }
1310  return success();
1311 }
1312 
1313 //===----------------------------------------------------------------------===//
1314 // spv.BranchOp
1315 //===----------------------------------------------------------------------===//
1316 
1318 spirv::BranchOp::getMutableSuccessorOperands(unsigned index) {
1319  assert(index == 0 && "invalid successor index");
1320  return targetOperandsMutable();
1321 }
1322 
1323 //===----------------------------------------------------------------------===//
1324 // spv.BranchConditionalOp
1325 //===----------------------------------------------------------------------===//
1326 
1328 spirv::BranchConditionalOp::getMutableSuccessorOperands(unsigned index) {
1329  assert(index < 2 && "invalid successor index");
1330  return index == kTrueIndex ? trueTargetOperandsMutable()
1331  : falseTargetOperandsMutable();
1332 }
1333 
1335  OperationState &state) {
1336  auto &builder = parser.getBuilder();
1337  OpAsmParser::OperandType condInfo;
1338  Block *dest;
1339 
1340  // Parse the condition.
1341  Type boolTy = builder.getI1Type();
1342  if (parser.parseOperand(condInfo) ||
1343  parser.resolveOperand(condInfo, boolTy, state.operands))
1344  return failure();
1345 
1346  // Parse the optional branch weights.
1347  if (succeeded(parser.parseOptionalLSquare())) {
1348  IntegerAttr trueWeight, falseWeight;
1349  NamedAttrList weights;
1350 
1351  auto i32Type = builder.getIntegerType(32);
1352  if (parser.parseAttribute(trueWeight, i32Type, "weight", weights) ||
1353  parser.parseComma() ||
1354  parser.parseAttribute(falseWeight, i32Type, "weight", weights) ||
1355  parser.parseRSquare())
1356  return failure();
1357 
1359  builder.getArrayAttr({trueWeight, falseWeight}));
1360  }
1361 
1362  // Parse the true branch.
1363  SmallVector<Value, 4> trueOperands;
1364  if (parser.parseComma() ||
1365  parser.parseSuccessorAndUseList(dest, trueOperands))
1366  return failure();
1367  state.addSuccessors(dest);
1368  state.addOperands(trueOperands);
1369 
1370  // Parse the false branch.
1371  SmallVector<Value, 4> falseOperands;
1372  if (parser.parseComma() ||
1373  parser.parseSuccessorAndUseList(dest, falseOperands))
1374  return failure();
1375  state.addSuccessors(dest);
1376  state.addOperands(falseOperands);
1377  state.addAttribute(
1378  spirv::BranchConditionalOp::getOperandSegmentSizeAttr(),
1379  builder.getI32VectorAttr({1, static_cast<int32_t>(trueOperands.size()),
1380  static_cast<int32_t>(falseOperands.size())}));
1381 
1382  return success();
1383 }
1384 
1385 static void print(spirv::BranchConditionalOp branchOp, OpAsmPrinter &printer) {
1386  printer << ' ' << branchOp.condition();
1387 
1388  if (auto weights = branchOp.branch_weights()) {
1389  printer << " [";
1390  llvm::interleaveComma(weights->getValue(), printer, [&](Attribute a) {
1391  printer << a.cast<IntegerAttr>().getInt();
1392  });
1393  printer << "]";
1394  }
1395 
1396  printer << ", ";
1397  printer.printSuccessorAndUseList(branchOp.getTrueBlock(),
1398  branchOp.getTrueBlockArguments());
1399  printer << ", ";
1400  printer.printSuccessorAndUseList(branchOp.getFalseBlock(),
1401  branchOp.getFalseBlockArguments());
1402 }
1403 
1404 static LogicalResult verify(spirv::BranchConditionalOp branchOp) {
1405  if (auto weights = branchOp.branch_weights()) {
1406  if (weights->getValue().size() != 2) {
1407  return branchOp.emitOpError("must have exactly two branch weights");
1408  }
1409  if (llvm::all_of(*weights, [](Attribute attr) {
1410  return attr.cast<IntegerAttr>().getValue().isNullValue();
1411  }))
1412  return branchOp.emitOpError("branch weights cannot both be zero");
1413  }
1414 
1415  return success();
1416 }
1417 
1418 //===----------------------------------------------------------------------===//
1419 // spv.CompositeConstruct
1420 //===----------------------------------------------------------------------===//
1421 
1423  OperationState &state) {
1425  Type type;
1426  auto loc = parser.getCurrentLocation();
1427 
1428  if (parser.parseOperandList(operands) || parser.parseColonType(type)) {
1429  return failure();
1430  }
1431  auto cType = type.dyn_cast<spirv::CompositeType>();
1432  if (!cType) {
1433  return parser.emitError(
1434  loc, "result type must be a composite type, but provided ")
1435  << type;
1436  }
1437 
1438  if (cType.hasCompileTimeKnownNumElements() &&
1439  operands.size() != cType.getNumElements()) {
1440  return parser.emitError(loc, "has incorrect number of operands: expected ")
1441  << cType.getNumElements() << ", but provided " << operands.size();
1442  }
1443  // TODO: Add support for constructing a vector type from the vector operands.
1444  // According to the spec: "for constructing a vector, the operands may
1445  // also be vectors with the same component type as the Result Type component
1446  // type".
1447  SmallVector<Type, 4> elementTypes;
1448  elementTypes.reserve(operands.size());
1449  for (auto index : llvm::seq<uint32_t>(0, operands.size())) {
1450  elementTypes.push_back(cType.getElementType(index));
1451  }
1452  state.addTypes(type);
1453  return parser.resolveOperands(operands, elementTypes, loc, state.operands);
1454 }
1455 
1456 static void print(spirv::CompositeConstructOp compositeConstructOp,
1457  OpAsmPrinter &printer) {
1458  printer << " " << compositeConstructOp.constituents() << " : "
1459  << compositeConstructOp.getResult().getType();
1460 }
1461 
1462 static LogicalResult verify(spirv::CompositeConstructOp compositeConstructOp) {
1463  auto cType = compositeConstructOp.getType().cast<spirv::CompositeType>();
1464  SmallVector<Value, 4> constituents(compositeConstructOp.constituents());
1465 
1466  if (cType.isa<spirv::CooperativeMatrixNVType>()) {
1467  if (constituents.size() != 1)
1468  return compositeConstructOp.emitError(
1469  "has incorrect number of operands: expected ")
1470  << "1, but provided " << constituents.size();
1471  } else if (constituents.size() != cType.getNumElements()) {
1472  return compositeConstructOp.emitError(
1473  "has incorrect number of operands: expected ")
1474  << cType.getNumElements() << ", but provided "
1475  << constituents.size();
1476  }
1477 
1478  for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
1479  if (constituents[index].getType() != cType.getElementType(index)) {
1480  return compositeConstructOp.emitError(
1481  "operand type mismatch: expected operand type ")
1482  << cType.getElementType(index) << ", but provided "
1483  << constituents[index].getType();
1484  }
1485  }
1486 
1487  return success();
1488 }
1489 
1490 //===----------------------------------------------------------------------===//
1491 // spv.CompositeExtractOp
1492 //===----------------------------------------------------------------------===//
1493 
1494 void spirv::CompositeExtractOp::build(OpBuilder &builder, OperationState &state,
1495  Value composite,
1496  ArrayRef<int32_t> indices) {
1497  auto indexAttr = builder.getI32ArrayAttr(indices);
1498  auto elementType =
1499  getElementType(composite.getType(), indexAttr, state.location);
1500  if (!elementType) {
1501  return;
1502  }
1503  build(builder, state, elementType, composite, indexAttr);
1504 }
1505 
1507  OperationState &state) {
1508  OpAsmParser::OperandType compositeInfo;
1509  Attribute indicesAttr;
1510  Type compositeType;
1511  llvm::SMLoc attrLocation;
1512 
1513  if (parser.parseOperand(compositeInfo) ||
1514  parser.getCurrentLocation(&attrLocation) ||
1515  parser.parseAttribute(indicesAttr, kIndicesAttrName, state.attributes) ||
1516  parser.parseColonType(compositeType) ||
1517  parser.resolveOperand(compositeInfo, compositeType, state.operands)) {
1518  return failure();
1519  }
1520 
1521  Type resultType =
1522  getElementType(compositeType, indicesAttr, parser, attrLocation);
1523  if (!resultType) {
1524  return failure();
1525  }
1526  state.addTypes(resultType);
1527  return success();
1528 }
1529 
1530 static void print(spirv::CompositeExtractOp compositeExtractOp,
1531  OpAsmPrinter &printer) {
1532  printer << ' ' << compositeExtractOp.composite()
1533  << compositeExtractOp.indices() << " : "
1534  << compositeExtractOp.composite().getType();
1535 }
1536 
1537 static LogicalResult verify(spirv::CompositeExtractOp compExOp) {
1538  auto indicesArrayAttr = compExOp.indices().dyn_cast<ArrayAttr>();
1539  auto resultType = getElementType(compExOp.composite().getType(),
1540  indicesArrayAttr, compExOp.getLoc());
1541  if (!resultType)
1542  return failure();
1543 
1544  if (resultType != compExOp.getType()) {
1545  return compExOp.emitOpError("invalid result type: expected ")
1546  << resultType << " but provided " << compExOp.getType();
1547  }
1548 
1549  return success();
1550 }
1551 
1552 //===----------------------------------------------------------------------===//
1553 // spv.CompositeInsert
1554 //===----------------------------------------------------------------------===//
1555 
1556 void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state,
1557  Value object, Value composite,
1558  ArrayRef<int32_t> indices) {
1559  auto indexAttr = builder.getI32ArrayAttr(indices);
1560  build(builder, state, composite.getType(), object, composite, indexAttr);
1561 }
1562 
1564  OperationState &state) {
1566  Type objectType, compositeType;
1567  Attribute indicesAttr;
1568  auto loc = parser.getCurrentLocation();
1569 
1570  return failure(
1571  parser.parseOperandList(operands, 2) ||
1572  parser.parseAttribute(indicesAttr, kIndicesAttrName, state.attributes) ||
1573  parser.parseColonType(objectType) ||
1574  parser.parseKeywordType("into", compositeType) ||
1575  parser.resolveOperands(operands, {objectType, compositeType}, loc,
1576  state.operands) ||
1577  parser.addTypesToList(compositeType, state.types));
1578 }
1579 
1580 static LogicalResult verify(spirv::CompositeInsertOp compositeInsertOp) {
1581  auto indicesArrayAttr = compositeInsertOp.indices().dyn_cast<ArrayAttr>();
1582  auto objectType =
1583  getElementType(compositeInsertOp.composite().getType(), indicesArrayAttr,
1584  compositeInsertOp.getLoc());
1585  if (!objectType)
1586  return failure();
1587 
1588  if (objectType != compositeInsertOp.object().getType()) {
1589  return compositeInsertOp.emitOpError("object operand type should be ")
1590  << objectType << ", but found "
1591  << compositeInsertOp.object().getType();
1592  }
1593 
1594  if (compositeInsertOp.composite().getType() != compositeInsertOp.getType()) {
1595  return compositeInsertOp.emitOpError("result type should be the same as "
1596  "the composite type, but found ")
1597  << compositeInsertOp.composite().getType() << " vs "
1598  << compositeInsertOp.getType();
1599  }
1600 
1601  return success();
1602 }
1603 
1604 static void print(spirv::CompositeInsertOp compositeInsertOp,
1605  OpAsmPrinter &printer) {
1606  printer << " " << compositeInsertOp.object() << ", "
1607  << compositeInsertOp.composite() << compositeInsertOp.indices()
1608  << " : " << compositeInsertOp.object().getType() << " into "
1609  << compositeInsertOp.composite().getType();
1610 }
1611 
1612 //===----------------------------------------------------------------------===//
1613 // spv.Constant
1614 //===----------------------------------------------------------------------===//
1615 
1617  Attribute value;
1618  if (parser.parseAttribute(value, kValueAttrName, state.attributes))
1619  return failure();
1620 
1621  Type type = value.getType();
1622  if (type.isa<NoneType, TensorType>()) {
1623  if (parser.parseColonType(type))
1624  return failure();
1625  }
1626 
1627  return parser.addTypeToList(type, state.types);
1628 }
1629 
1630 static void print(spirv::ConstantOp constOp, OpAsmPrinter &printer) {
1631  printer << ' ' << constOp.value();
1632  if (constOp.getType().isa<spirv::ArrayType>())
1633  printer << " : " << constOp.getType();
1634 }
1635 
1636 static LogicalResult verify(spirv::ConstantOp constOp) {
1637  auto opType = constOp.getType();
1638  auto value = constOp.value();
1639  auto valueType = value.getType();
1640 
1641  // ODS already generates checks to make sure the result type is valid. We just
1642  // need to additionally check that the value's attribute type is consistent
1643  // with the result type.
1644  if (value.isa<IntegerAttr, FloatAttr>()) {
1645  if (valueType != opType)
1646  return constOp.emitOpError("result type (")
1647  << opType << ") does not match value type (" << valueType << ")";
1648  return success();
1649  }
1650  if (value.isa<DenseIntOrFPElementsAttr, SparseElementsAttr>()) {
1651  if (valueType == opType)
1652  return success();
1653  auto arrayType = opType.dyn_cast<spirv::ArrayType>();
1654  auto shapedType = valueType.dyn_cast<ShapedType>();
1655  if (!arrayType) {
1656  return constOp.emitOpError(
1657  "must have spv.array result type for array value");
1658  }
1659 
1660  int numElements = arrayType.getNumElements();
1661  auto opElemType = arrayType.getElementType();
1662  while (auto t = opElemType.dyn_cast<spirv::ArrayType>()) {
1663  numElements *= t.getNumElements();
1664  opElemType = t.getElementType();
1665  }
1666  if (!opElemType.isIntOrFloat())
1667  return constOp.emitOpError("only support nested array result type");
1668 
1669  auto valueElemType = shapedType.getElementType();
1670  if (valueElemType != opElemType) {
1671  return constOp.emitOpError("result element type (")
1672  << opElemType << ") does not match value element type ("
1673  << valueElemType << ")";
1674  }
1675 
1676  if (numElements != shapedType.getNumElements()) {
1677  return constOp.emitOpError("result number of elements (")
1678  << numElements << ") does not match value number of elements ("
1679  << shapedType.getNumElements() << ")";
1680  }
1681  return success();
1682  }
1683  if (auto attayAttr = value.dyn_cast<ArrayAttr>()) {
1684  auto arrayType = opType.dyn_cast<spirv::ArrayType>();
1685  if (!arrayType)
1686  return constOp.emitOpError(
1687  "must have spv.array result type for array value");
1688  Type elemType = arrayType.getElementType();
1689  for (Attribute element : attayAttr.getValue()) {
1690  if (element.getType() != elemType)
1691  return constOp.emitOpError("has array element whose type (")
1692  << element.getType()
1693  << ") does not match the result element type (" << elemType
1694  << ')';
1695  }
1696  return success();
1697  }
1698  return constOp.emitOpError("cannot have value of type ") << valueType;
1699 }
1700 
1701 bool spirv::ConstantOp::isBuildableWith(Type type) {
1702  // Must be valid SPIR-V type first.
1703  if (!type.isa<spirv::SPIRVType>())
1704  return false;
1705 
1706  if (isa<SPIRVDialect>(type.getDialect())) {
1707  // TODO: support constant struct
1708  return type.isa<spirv::ArrayType>();
1709  }
1710 
1711  return true;
1712 }
1713 
1714 spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
1715  OpBuilder &builder) {
1716  if (auto intType = type.dyn_cast<IntegerType>()) {
1717  unsigned width = intType.getWidth();
1718  if (width == 1)
1719  return builder.create<spirv::ConstantOp>(loc, type,
1720  builder.getBoolAttr(false));
1721  return builder.create<spirv::ConstantOp>(
1722  loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
1723  }
1724  if (auto floatType = type.dyn_cast<FloatType>()) {
1725  return builder.create<spirv::ConstantOp>(
1726  loc, type, builder.getFloatAttr(floatType, 0.0));
1727  }
1728  if (auto vectorType = type.dyn_cast<VectorType>()) {
1729  Type elemType = vectorType.getElementType();
1730  if (elemType.isa<IntegerType>()) {
1731  return builder.create<spirv::ConstantOp>(
1732  loc, type,
1734  IntegerAttr::get(elemType, 0.0).getValue()));
1735  }
1736  if (elemType.isa<FloatType>()) {
1737  return builder.create<spirv::ConstantOp>(
1738  loc, type,
1740  FloatAttr::get(elemType, 0.0).getValue()));
1741  }
1742  }
1743 
1744  llvm_unreachable("unimplemented types for ConstantOp::getZero()");
1745 }
1746 
/// Creates a `spv.Constant` op holding the value one of `type` at `loc`.
///
/// Handled types: integers (width 1 produces a `true` bool attribute), floats,
/// and vectors whose element type is an integer or a float. Any other type
/// falls through to llvm_unreachable.
1747 spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
1748  OpBuilder &builder) {
1749  if (auto intType = type.dyn_cast<IntegerType>()) {
1750  unsigned width = intType.getWidth();
  // Width-1 integers are treated as booleans: use a BoolAttr.
1751  if (width == 1)
1752  return builder.create<spirv::ConstantOp>(loc, type,
1753  builder.getBoolAttr(true));
1754  return builder.create<spirv::ConstantOp>(
1755  loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
1756  }
1757  if (auto floatType = type.dyn_cast<FloatType>()) {
1758  return builder.create<spirv::ConstantOp>(
1759  loc, type, builder.getFloatAttr(floatType, 1.0));
1760  }
  // Vectors: the per-element one comes from IntegerAttr::get /
  // FloatAttr::get and is wrapped by an enclosing attribute constructor
  // (presumably a splat DenseElementsAttr over `vectorType` — TODO confirm).
  // NOTE(review): `IntegerAttr::get(elemType, 1.0)` passes a double literal to
  // an integer-valued overload; `1` would state the intent directly.
1761  if (auto vectorType = type.dyn_cast<VectorType>()) {
1762  Type elemType = vectorType.getElementType();
1763  if (elemType.isa<IntegerType>()) {
1764  return builder.create<spirv::ConstantOp>(
1765  loc, type,
1767  IntegerAttr::get(elemType, 1.0).getValue()));
1768  }
1769  if (elemType.isa<FloatType>()) {
1770  return builder.create<spirv::ConstantOp>(
1771  loc, type,
1773  FloatAttr::get(elemType, 1.0).getValue()));
1774  }
1775  }
1776 
1777  llvm_unreachable("unimplemented types for ConstantOp::getOne()");
1778 }
1779 
1780 void mlir::spirv::ConstantOp::getAsmResultNames(
1781  llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
1782  Type type = getType();
1783 
1784  SmallString<32> specialNameBuffer;
1785  llvm::raw_svector_ostream specialName(specialNameBuffer);
1786  specialName << "cst";
1787 
1788  IntegerType intTy = type.dyn_cast<IntegerType>();
1789 
1790  if (IntegerAttr intCst = value().dyn_cast<IntegerAttr>()) {
1791  if (intTy && intTy.getWidth() == 1) {
1792  return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
1793  }
1794 
1795  if (intTy.isSignless()) {
1796  specialName << intCst.getInt();
1797  } else {
1798  specialName << intCst.getSInt();
1799  }
1800  }
1801 
1802  if (intTy || type.isa<FloatType>()) {
1803  specialName << '_' << type;
1804  }
1805 
1806  if (auto vecType = type.dyn_cast<VectorType>()) {
1807  specialName << "_vec_";
1808  specialName << vecType.getDimSize(0);
1809 
1810  Type elementType = vecType.getElementType();
1811 
1812  if (elementType.isa<IntegerType>() || elementType.isa<FloatType>()) {
1813  specialName << "x" << elementType;
1814  }
1815  }
1816 
1817  setNameFn(getResult(), specialName.str());
1818 }
1819 
1820 void mlir::spirv::AddressOfOp::getAsmResultNames(
1821  llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
1822  SmallString<32> specialNameBuffer;
1823  llvm::raw_svector_ostream specialName(specialNameBuffer);
1824  specialName << variable() << "_addr";
1825  setNameFn(getResult(), specialName.str());
1826 }
1827 
1828 //===----------------------------------------------------------------------===//
1829 // spv.EntryPoint
1830 //===----------------------------------------------------------------------===//
1831 
1832 void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state,
1833  spirv::ExecutionModel executionModel,
1834  spirv::FuncOp function,
1835  ArrayRef<Attribute> interfaceVars) {
1836  build(builder, state,
1837  spirv::ExecutionModelAttr::get(builder.getContext(), executionModel),
1838  SymbolRefAttr::get(function), builder.getArrayAttr(interfaceVars));
1839 }
1840 
1842  OperationState &state) {
  // Parses `"<ExecutionModel>" @fn` optionally followed by
  // `, @var0, @var1, ...` (the interface variable symbols).
1843  spirv::ExecutionModel execModel;
1845  SmallVector<Type, 0> idTypes;
1846  SmallVector<Attribute, 4> interfaceVars;
1847 
1848  FlatSymbolRefAttr fn;
1849  if (parseEnumStrAttr(execModel, parser, state) ||
1850  parser.parseAttribute(fn, Type(), kFnNameAttrName, state.attributes)) {
1851  return failure();
1852  }
1853 
  // NOTE(review): `idTypes` is not used in the visible lines of this parser.
  // A comma after `@fn` introduces the interface variable list.
1854  if (!parser.parseOptionalComma()) {
1855  // Parse the interface variables
1856  if (parser.parseCommaSeparatedList([&]() -> ParseResult {
1857  // The name of the interface variable attribute isn't important: `attrs` is a scratch list that is discarded.
1858  FlatSymbolRefAttr var;
1859  NamedAttrList attrs;
1860  if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
1861  return failure();
1862  interfaceVars.push_back(var);
1863  return success();
1864  }))
1865  return failure();
1866  }
  // Record the collected interface symbols (possibly none) as an ArrayAttr.
1868  parser.getBuilder().getArrayAttr(interfaceVars));
1869  return success();
1870 }
1871 
1872 static void print(spirv::EntryPointOp entryPointOp, OpAsmPrinter &printer) {
1873  printer << " \"" << stringifyExecutionModel(entryPointOp.execution_model())
1874  << "\" ";
1875  printer.printSymbolName(entryPointOp.fn());
1876  auto interfaceVars = entryPointOp.interface().getValue();
1877  if (!interfaceVars.empty()) {
1878  printer << ", ";
1879  llvm::interleaveComma(interfaceVars, printer);
1880  }
1881 }
1882 
1883 static LogicalResult verify(spirv::EntryPointOp entryPointOp) {
1884  // Checks for fn and interface symbol reference are done in spirv::ModuleOp
1885  // verification.
1886  return success();
1887 }
1888 
1889 //===----------------------------------------------------------------------===//
1890 // spv.ExecutionMode
1891 //===----------------------------------------------------------------------===//
1892 
1893 void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state,
1894  spirv::FuncOp function,
1895  spirv::ExecutionMode executionMode,
1896  ArrayRef<int32_t> params) {
1897  build(builder, state, SymbolRefAttr::get(function),
1898  spirv::ExecutionModeAttr::get(builder.getContext(), executionMode),
1899  builder.getI32ArrayAttr(params));
1900 }
1901 
1903  OperationState &state) {
  // Parses `@fn "<ExecutionMode>"` followed by zero or more `, <integer>`
  // parameters, each parsed with i32 as the expected type.
1904  spirv::ExecutionMode execMode;
  // NOTE(review): `fn` is parsed as a generic Attribute here, whereas the
  // spv.EntryPoint parser requires a FlatSymbolRefAttr for its `fn` field.
1905  Attribute fn;
1906  if (parser.parseAttribute(fn, kFnNameAttrName, state.attributes) ||
1907  parseEnumStrAttr(execMode, parser, state)) {
1908  return failure();
1909  }
1910 
1911  SmallVector<int32_t, 4> values;
1912  Type i32Type = parser.getBuilder().getIntegerType(32);
1913  while (!parser.parseOptionalComma()) {
1914  NamedAttrList attr;
1915  Attribute value;
1916  if (parser.parseAttribute(value, i32Type, "value", attr)) {
1917  return failure();
1918  }
  // NOTE(review): cast<IntegerAttr>() asserts if the literal does not
  // produce an IntegerAttr (e.g. a string literal); a dyn_cast plus
  // emitError would report a proper parse error instead.
1919  values.push_back(value.cast<IntegerAttr>().getInt());
1920  }
  // Store the collected parameters as an i32 array attribute.
1922  parser.getBuilder().getI32ArrayAttr(values));
1923  return success();
1924 }
1925 
1926 static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter &printer) {
1927  printer << " ";
1928  printer.printSymbolName(execModeOp.fn());
1929  printer << " \"" << stringifyExecutionMode(execModeOp.execution_mode())
1930  << "\"";
1931  auto values = execModeOp.values();
1932  if (values.empty())
1933  return;
1934  printer << ", ";
1935  llvm::interleaveComma(values, printer, [&](Attribute a) {
1936  printer << a.cast<IntegerAttr>().getInt();
1937  });
1938 }
1939 
1940 //===----------------------------------------------------------------------===//
1941 // spv.func
1942 //===----------------------------------------------------------------------===//
1943 
// Parses a spv.func:
//   @name(<signature>) "<FunctionControl>" [attributes {...}] [<body region>]
1946  SmallVector<NamedAttrList> argAttrs;
1947  SmallVector<NamedAttrList> resultAttrs;
1948  SmallVector<Type> argTypes;
1949  SmallVector<Type> resultTypes;
1950  SmallVector<Location> argLocations;
1951  auto &builder = parser.getBuilder();
1952 
1953  // Parse the name as a symbol.
1954  StringAttr nameAttr;
1955  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1956  state.attributes))
1957  return failure();
1958 
1959  // Parse the function signature (variadic signatures are rejected).
1960  bool isVariadic = false;
1962  parser, /*allowVariadic=*/false, entryArgs, argTypes, argAttrs,
1963  argLocations, isVariadic, resultTypes, resultAttrs))
1964  return failure();
1965 
1966  auto fnType = builder.getFunctionType(argTypes, resultTypes);
1968  TypeAttr::get(fnType));
1969 
1970  // Parse the function control string; it is required here (the printer always emits it).
1971  spirv::FunctionControl fnControl;
1972  if (parseEnumStrAttr(fnControl, parser, state))
1973  return failure();
1974 
1975  // If additional attributes are present, parse them.
1976  if (parser.parseOptionalAttrDictWithKeyword(state.attributes))
1977  return failure();
1978 
1979  // Attach the parsed argument and result attributes to the op.
1980  assert(argAttrs.size() == argTypes.size());
1981  assert(resultAttrs.size() == resultTypes.size());
1982  function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs,
1983  resultAttrs);
1984 
1985  // Parse the optional function body. Entry block argument types are only
  // supplied when named arguments were parsed; a missing body is not an error
  // (the region then stays empty, which the printer treats as external).
1986  auto *body = state.addRegion();
1987  OptionalParseResult result = parser.parseOptionalRegion(
1988  *body, entryArgs, entryArgs.empty() ? ArrayRef<Type>() : argTypes);
1989  return failure(result.hasValue() && failed(*result));
1990 }
1991 
/// Prints a spv.func as `@name(<signature>) "<FunctionControl>"`, then the
/// remaining attributes and, for non-external functions, the body region.
1992 static void print(spirv::FuncOp fnOp, OpAsmPrinter &printer) {
1993  // Print function name, signature, and control.
1994  printer << " ";
1995  printer.printSymbolName(fnOp.sym_name());
1996  auto fnType = fnOp.getType();
1998  printer, fnOp, fnType.getInputs(),
1999  /*isVariadic=*/false, fnType.getResults());
  // Function control is printed as a quoted string; its attribute name is
  // passed as an elided attribute in the next statement so it is not repeated.
2000  printer << " \"" << spirv::stringifyFunctionControl(fnOp.function_control())
2001  << "\"";
2003  printer, fnOp, fnType.getNumInputs(), fnType.getNumResults(),
2004  {spirv::attributeName<spirv::FunctionControl>()});
2005 
2006  // Print the body if this is not an external function. Entry block
  // arguments are not printed here (presumably they already appear, named, in
  // the signature printed above).
2007  Region &body = fnOp.body();
2008  if (!body.empty()) {
2009  printer << ' ';
2010  printer.printRegion(body, /*printEntryBlockArgs=*/false,
2011  /*printBlockTerminators=*/true);
2012  }
2013 }
2014 
2015 LogicalResult spirv::FuncOp::verifyType() {
2016  auto type = getTypeAttr().getValue();
2017  if (!type.isa<FunctionType>())
2018  return emitOpError("requires '" + getTypeAttrName() +
2019  "' attribute of function type");
2020  if (getType().getNumResults() > 1)
2021  return emitOpError("cannot have more than one result");
2022  return success();
2023 }
2024 
2025 LogicalResult spirv::FuncOp::verifyBody() {
2026  FunctionType fnType = getType();
2027 
2028  auto walkResult = walk([fnType](Operation *op) -> WalkResult {
2029  if (auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
2030  if (fnType.getNumResults() != 0)
2031  return retOp.emitOpError("cannot be used in functions returning value");
2032  } else if (auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
2033  if (fnType.getNumResults() != 1)
2034  return retOp.emitOpError(
2035  "returns 1 value but enclosing function requires ")
2036  << fnType.getNumResults() << " results";
2037 
2038  auto retOperandType = retOp.value().getType();
2039  auto fnResultType = fnType.getResult(0);
2040  if (retOperandType != fnResultType)
2041  return retOp.emitOpError(" return value's type (")
2042  << retOperandType << ") mismatch with function's result type ("
2043  << fnResultType << ")";
2044  }
2045  return WalkResult::advance();
2046  });
2047 
2048  // TODO: verify other bits like linkage type.
2049 
2050  return failure(walkResult.wasInterrupted());
2051 }
2052 
/// Builds a spv.func named `name` with function type `type` and function
/// control `control`; `attrs` are appended as additional attributes. The body
/// region is created without blocks.
2053 void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
2054  StringRef name, FunctionType type,
2055  spirv::FunctionControl control,
2056  ArrayRef<NamedAttribute> attrs) {
2058  builder.getStringAttr(name));
2059  state.addAttribute(getTypeAttrName(), TypeAttr::get(type));
  // Function control is stored as an i32 attribute.
2060  state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
2061  builder.getI32IntegerAttr(static_cast<uint32_t>(control)));
2062  state.attributes.append(attrs.begin(), attrs.end());
  // Empty region: the printer treats a function without body blocks as external.
2063  state.addRegion();
2064 }
2065 
2066 // CallableOpInterface
2067 Region *spirv::FuncOp::getCallableRegion() {
2068  return isExternal() ? nullptr : &body();
2069 }
2070 
2071 // CallableOpInterface
2072 ArrayRef<Type> spirv::FuncOp::getCallableResults() {
2073  return getType().getResults();
2074 }
2075 
2076 //===----------------------------------------------------------------------===//
2077 // spv.FunctionCall
2078 //===----------------------------------------------------------------------===//
2079 
2080 static LogicalResult verify(spirv::FunctionCallOp functionCallOp) {
2081  auto fnName = functionCallOp.calleeAttr();
2082 
2083  auto funcOp =
2084  dyn_cast_or_null<spirv::FuncOp>(SymbolTable::lookupNearestSymbolFrom(
2085  functionCallOp->getParentOp(), fnName));
2086  if (!funcOp) {
2087  return functionCallOp.emitOpError("callee function '")
2088  << fnName.getValue() << "' not found in nearest symbol table";
2089  }
2090 
2091  auto functionType = funcOp.getType();
2092 
2093  if (functionCallOp.getNumResults() > 1) {
2094  return functionCallOp.emitOpError(
2095  "expected callee function to have 0 or 1 result, but provided ")
2096  << functionCallOp.getNumResults();
2097  }
2098 
2099  if (functionType.getNumInputs() != functionCallOp.getNumOperands()) {
2100  return functionCallOp.emitOpError(
2101  "has incorrect number of operands for callee: expected ")
2102  << functionType.getNumInputs() << ", but provided "
2103  << functionCallOp.getNumOperands();
2104  }
2105 
2106  for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
2107  if (functionCallOp.getOperand(i).getType() != functionType.getInput(i)) {
2108  return functionCallOp.emitOpError(
2109  "operand type mismatch: expected operand type ")
2110  << functionType.getInput(i) << ", but provided "
2111  << functionCallOp.getOperand(i).getType() << " for operand number "
2112  << i;
2113  }
2114  }
2115 
2116  if (functionType.getNumResults() != functionCallOp.getNumResults()) {
2117  return functionCallOp.emitOpError(
2118  "has incorrect number of results has for callee: expected ")
2119  << functionType.getNumResults() << ", but provided "
2120  << functionCallOp.getNumResults();
2121  }
2122 
2123  if (functionCallOp.getNumResults() &&
2124  (functionCallOp.getResult(0).getType() != functionType.getResult(0))) {
2125  return functionCallOp.emitOpError("result type mismatch: expected ")
2126  << functionType.getResult(0) << ", but provided "
2127  << functionCallOp.getResult(0).getType();
2128  }
2129 
2130  return success();
2131 }
2132 
2133 CallInterfaceCallable spirv::FunctionCallOp::getCallableForCallee() {
2134  return (*this)->getAttrOfType<SymbolRefAttr>(kCallee);
2135 }
2136 
2137 Operation::operand_range spirv::FunctionCallOp::getArgOperands() {
2138  return arguments();
2139 }
2140 
2141 //===----------------------------------------------------------------------===//
2142 // spv.GlobalVariable
2143 //===----------------------------------------------------------------------===//
2144 
2145 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
2146  Type type, StringRef name,
2147  unsigned descriptorSet, unsigned binding) {
2148  build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
2149  state.addAttribute(
2150  spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
2151  builder.getI32IntegerAttr(descriptorSet));
2152  state.addAttribute(
2153  spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
2154  builder.getI32IntegerAttr(binding));
2155 }
2156 
2157 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
2158  Type type, StringRef name,
2159  spirv::BuiltIn builtin) {
2160  build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
2161  state.addAttribute(
2162  spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
2163  builder.getStringAttr(spirv::stringifyBuiltIn(builtin)));
2164 }
2165 
2167  OperationState &state) {
  // Parses `@name [initializer(@symbol)] <decorations> : <spv.ptr type>`.
2168  // Parse variable name.
2169  StringAttr nameAttr;
2170  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
2171  state.attributes)) {
2172  return failure();
2173  }
2174 
2175  // Parse the optional `initializer(@symbol)` clause.
2177  FlatSymbolRefAttr initSymbol;
2178  if (parser.parseLParen() ||
2179  parser.parseAttribute(initSymbol, Type(), kInitializerAttrName,
2180  state.attributes) ||
2181  parser.parseRParen())
2182  return failure();
2183  }
2184 
  // Decorations are parsed by a helper defined elsewhere in this file.
2185  if (parseVariableDecorations(parser, state)) {
2186  return failure();
2187  }
2188 
  // The variable's type must be a spv.ptr; it is stored as the type attribute.
2189  Type type;
2190  auto loc = parser.getCurrentLocation();
2191  if (parser.parseColonType(type)) {
2192  return failure();
2193  }
2194  if (!type.isa<spirv::PointerType>()) {
2195  return parser.emitError(loc, "expected spv.ptr type");
2196  }
2197  state.addAttribute(kTypeAttrName, TypeAttr::get(type));
2198 
2199  return success();
2200 }
2201 
2202 static void print(spirv::GlobalVariableOp varOp, OpAsmPrinter &printer) {
2203  auto *op = varOp.getOperation();
2204  SmallVector<StringRef, 4> elidedAttrs{
2205  spirv::attributeName<spirv::StorageClass>()};
2206 
2207  // Print variable name.
2208  printer << ' ';
2209  printer.printSymbolName(varOp.sym_name());
2210  elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
2211 
2212  // Print optional initializer
2213  if (auto initializer = varOp.initializer()) {
2214  printer << " " << kInitializerAttrName << '(';
2215  printer.printSymbolName(initializer.getValue());
2216  printer << ')';
2217  elidedAttrs.push_back(kInitializerAttrName);
2218  }
2219 
2220  elidedAttrs.push_back(kTypeAttrName);
2221  printVariableDecorations(op, printer, elidedAttrs);
2222  printer << " : " << varOp.type();
2223 }
2224 
/// Verifies a spv.GlobalVariable: the storage class may be neither Generic nor
/// Function, and an `initializer`, if present, must name a
/// spv.GlobalVariable or spv.SpecConstant.
2225 static LogicalResult verify(spirv::GlobalVariableOp varOp) {
2226  // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
2227  // object. It cannot be Generic. It must be the same as the Storage Class
2228  // operand of the Result Type."
2229  // Also, Function storage class is reserved by spv.Variable.
2230  auto storageClass = varOp.storageClass();
2231  if (storageClass == spirv::StorageClass::Generic ||
2232  storageClass == spirv::StorageClass::Function) {
2233  return varOp.emitOpError("storage class cannot be '")
2234  << stringifyStorageClass(storageClass) << "'";
2235  }
2236 
  // The initializer symbol is resolved relative to the variable's parent op;
  // an unresolved symbol is reported with the same diagnostic as a wrong kind.
2237  if (auto init =
2238  varOp->getAttrOfType<FlatSymbolRefAttr>(kInitializerAttrName)) {
2240  varOp->getParentOp(), init.getAttr());
2241  // TODO: Currently only variable initialization with specialization
2242  // constants and other variables is supported. They could be normal
2243  // constants in the module scope as well.
2244  if (!initOp ||
2245  !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp>(initOp)) {
2246  return varOp.emitOpError("initializer must be result of a "
2247  "spv.SpecConstant or spv.GlobalVariable op");
2248  }
2249  }
2250 
2251  return success();
2252 }
2253 
2254 //===----------------------------------------------------------------------===//
2255 // spv.GroupBroadcast
2256 //===----------------------------------------------------------------------===//
2257 
2258 static LogicalResult verify(spirv::GroupBroadcastOp broadcastOp) {
2259  spirv::Scope scope = broadcastOp.execution_scope();
2260  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2261  return broadcastOp.emitOpError(
2262  "execution scope must be 'Workgroup' or 'Subgroup'");
2263 
2264  if (auto localIdTy = broadcastOp.localid().getType().dyn_cast<VectorType>())
2265  if (!(localIdTy.getNumElements() == 2 || localIdTy.getNumElements() == 3))
2266  return broadcastOp.emitOpError("localid is a vector and can be with only "
2267  " 2 or 3 components, actual number is ")
2268  << localIdTy.getNumElements();
2269 
2270  return success();
2271 }
2272 
2273 //===----------------------------------------------------------------------===//
2274 // spv.GroupNonUniformBallotOp
2275 //===----------------------------------------------------------------------===//
2276 
2277 static LogicalResult verify(spirv::GroupNonUniformBallotOp ballotOp) {
2278  spirv::Scope scope = ballotOp.execution_scope();
2279  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2280  return ballotOp.emitOpError(
2281  "execution scope must be 'Workgroup' or 'Subgroup'");
2282 
2283  return success();
2284 }
2285 
2286 //===----------------------------------------------------------------------===//
2287 // spv.GroupNonUniformBroadcast
2288 //===----------------------------------------------------------------------===//
2289 
2290 static LogicalResult verify(spirv::GroupNonUniformBroadcastOp broadcastOp) {
2291  spirv::Scope scope = broadcastOp.execution_scope();
2292  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2293  return broadcastOp.emitOpError(
2294  "execution scope must be 'Workgroup' or 'Subgroup'");
2295 
2296  // SPIR-V spec: "Before version 1.5, Id must come from a
2297  // constant instruction.
2298  auto targetEnv = spirv::getDefaultTargetEnv(broadcastOp.getContext());
2299  if (auto spirvModule = broadcastOp->getParentOfType<spirv::ModuleOp>())
2300  targetEnv = spirv::lookupTargetEnvOrDefault(spirvModule);
2301 
2302  if (targetEnv.getVersion() < spirv::Version::V_1_5) {
2303  auto *idOp = broadcastOp.id().getDefiningOp();
2304  if (!idOp || !isa<spirv::ConstantOp, // for normal constant
2305  spirv::ReferenceOfOp>(idOp)) // for spec constant
2306  return broadcastOp.emitOpError("id must be the result of a constant op");
2307  }
2308 
2309  return success();
2310 }
2311 
2312 //===----------------------------------------------------------------------===//
2313 // spv.SubgroupBlockReadINTEL
2314 //===----------------------------------------------------------------------===//
2315 
2317  OperationState &state) {
2318  // Parses `"<StorageClass>" %ptr : <value type>`. The storage class is only
  // used to reconstruct the pointer operand's type; it is not stored as an
  // attribute on the op.
2319  spirv::StorageClass storageClass;
2320  OpAsmParser::OperandType ptrInfo;
2321  Type elementType;
2322  if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
2323  parser.parseColon() || parser.parseType(elementType)) {
2324  return failure();
2325  }
2326 
  // For a vector result, the pointer points to the vector's element type.
2327  auto ptrType = spirv::PointerType::get(elementType, storageClass);
2328  if (auto valVecTy = elementType.dyn_cast<VectorType>())
2329  ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
2330 
2331  if (parser.resolveOperand(ptrInfo, ptrType, state.operands)) {
2332  return failure();
2333  }
2334 
2335  state.addTypes(elementType);
2336  return success();
2337 }
2338 
2339 static void print(spirv::SubgroupBlockReadINTELOp blockReadOp,
2340  OpAsmPrinter &printer) {
2341  SmallVector<StringRef, 4> elidedAttrs;
2342  printer << " " << blockReadOp.ptr();
2343  printer << " : " << blockReadOp.getType();
2344 }
2345 
2346 static LogicalResult verify(spirv::SubgroupBlockReadINTELOp blockReadOp) {
2347  if (failed(verifyBlockReadWritePtrAndValTypes(blockReadOp, blockReadOp.ptr(),
2348  blockReadOp.value())))
2349  return failure();
2350 
2351  return success();
2352 }
2353 
2354 //===----------------------------------------------------------------------===//
2355 // spv.SubgroupBlockWriteINTEL
2356 //===----------------------------------------------------------------------===//
2357 
2359  OperationState &state) {
2360  // Parses `"<StorageClass>" %ptr, %value : <value type>`. The storage class is only used to build the pointer operand's type.
2361  spirv::StorageClass storageClass;
2363  auto loc = parser.getCurrentLocation();
2364  Type elementType;
  // Exactly two operands are required: the pointer and the value to write.
2365  if (parseEnumStrAttr(storageClass, parser) ||
2366  parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
2367  parser.parseType(elementType)) {
2368  return failure();
2369  }
2370 
  // For a vector value, the pointer points to the vector's element type.
2371  auto ptrType = spirv::PointerType::get(elementType, storageClass);
2372  if (auto valVecTy = elementType.dyn_cast<VectorType>())
2373  ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
2374 
2375  if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
2376  state.operands)) {
2377  return failure();
2378  }
2379  return success();
2380 }
2381 
2382 static void print(spirv::SubgroupBlockWriteINTELOp blockWriteOp,
2383  OpAsmPrinter &printer) {
2384  SmallVector<StringRef, 4> elidedAttrs;
2385  printer << " " << blockWriteOp.ptr() << ", " << blockWriteOp.value();
2386  printer << " : " << blockWriteOp.value().getType();
2387 }
2388 
/// Verifies `spv.SubgroupBlockWriteINTEL` by running the pointer/value type
/// check shared with the block read op on `ptr` and `value`.
2389 static LogicalResult verify(spirv::SubgroupBlockWriteINTELOp blockWriteOp) {
2391  blockWriteOp, blockWriteOp.ptr(), blockWriteOp.value())))
2392  return failure();
2393 
2394  return success();
2395 }
2396 
2397 //===----------------------------------------------------------------------===//
2398 // spv.GroupNonUniformElectOp
2399 //===----------------------------------------------------------------------===//
2400 
2401 static LogicalResult verify(spirv::GroupNonUniformElectOp groupOp) {
2402  spirv::Scope scope = groupOp.execution_scope();
2403  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2404  return groupOp.emitOpError(
2405  "execution scope must be 'Workgroup' or 'Subgroup'");
2406 
2407  return success();
2408 }
2409 
2410 //===----------------------------------------------------------------------===//
2411 // spv.LoadOp
2412 //===----------------------------------------------------------------------===//
2413 
2414 void spirv::LoadOp::build(OpBuilder &builder, OperationState &state,
2415  Value basePtr, MemoryAccessAttr memoryAccess,
2416  IntegerAttr alignment) {
2417  auto ptrType = basePtr.getType().cast<spirv::PointerType>();
2418  build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess,
2419  alignment);
2420 }
2421 
2423  // Parses `"<StorageClass>" %ptr [memory access] {attrs} : <value type>`; the pointer operand's type is rebuilt from the storage class and value type.
2424  spirv::StorageClass storageClass;
2425  OpAsmParser::OperandType ptrInfo;
2426  Type elementType;
2427  if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
2428  parseMemoryAccessAttributes(parser, state) ||
2429  parser.parseOptionalAttrDict(state.attributes) || parser.parseColon() ||
2430  parser.parseType(elementType)) {
2431  return failure();
2432  }
2433 
2434  auto ptrType = spirv::PointerType::get(elementType, storageClass);
2435  if (parser.resolveOperand(ptrInfo, ptrType, state.operands)) {
2436  return failure();
2437  }
2438 
  // The op's single result has the loaded value type.
2439  state.addTypes(elementType);
2440  return success();
2441 }
2442 
2443 static void print(spirv::LoadOp loadOp, OpAsmPrinter &printer) {
2444  auto *op = loadOp.getOperation();
2445  SmallVector<StringRef, 4> elidedAttrs;
2446  StringRef sc = stringifyStorageClass(
2447  loadOp.ptr().getType().cast<spirv::PointerType>().getStorageClass());
2448  printer << " \"" << sc << "\" " << loadOp.ptr();
2449 
2450  printMemoryAccessAttribute(loadOp, printer, elidedAttrs);
2451 
2452  printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
2453  printer << " : " << loadOp.getType();
2454 }
2455 
2456 static LogicalResult verify(spirv::LoadOp loadOp) {
2457  // SPIR-V spec : "Result Type is the type of the loaded object. It must be a
2458  // type with fixed size; i.e., it cannot be, nor include, any
2459  // OpTypeRuntimeArray types."
2460  if (failed(verifyLoadStorePtrAndValTypes(loadOp, loadOp.ptr(),
2461  loadOp.value()))) {
2462  return failure();
2463  }
2464  return verifyMemoryAccessAttribute(loadOp);
2465 }
2466 
2467 //===----------------------------------------------------------------------===//
2468 // spv.mlir.loop
2469 //===----------------------------------------------------------------------===//
2470 
2471 void spirv::LoopOp::build(OpBuilder &builder, OperationState &state) {
2472  state.addAttribute("loop_control",
2473  builder.getI32IntegerAttr(
2474  static_cast<uint32_t>(spirv::LoopControl::None)));
2475  state.addRegion();
2476 }
2477 
// Parses a spv.mlir.loop: a loop-control clause followed by the body region,
// whose entry block takes no arguments.
// NOTE(review): parseControlAttribute is defined elsewhere; the printer omits
// `control(...)` when the control is None, so the clause is presumably
// optional here — confirm.
2479  if (parseControlAttribute<spirv::LoopControl>(parser, state))
2480  return failure();
2481  return parser.parseRegion(*state.addRegion(), /*arguments=*/{},
2482  /*argTypes=*/{});
2483 }
2484 
2485 static void print(spirv::LoopOp loopOp, OpAsmPrinter &printer) {
2486  auto *op = loopOp.getOperation();
2487 
2488  auto control = loopOp.loop_control();
2489  if (control != spirv::LoopControl::None)
2490  printer << " control(" << spirv::stringifyLoopControl(control) << ")";
2491  printer << ' ';
2492  printer.printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false,
2493  /*printBlockTerminators=*/true);
2494 }
2495 
2496 /// Returns true if the given `srcBlock` contains only one `spv.Branch` to the
2497 /// given `dstBlock`.
2498 static inline bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) {
2499  // Check that there is only one op in the `srcBlock`.
2500  if (!llvm::hasSingleElement(srcBlock))
2501  return false;
2502 
2503  auto branchOp = dyn_cast<spirv::BranchOp>(srcBlock.back());
2504  return branchOp && branchOp.getSuccessor() == &dstBlock;
2505 }
2506 
/// Verifies the block structure of a spv.mlir.loop region (an empty region is
/// accepted). Checks run in a fixed order, and the first violated one
/// determines the emitted diagnostic.
2507 static LogicalResult verify(spirv::LoopOp loopOp) {
2508  auto *op = loopOp.getOperation();
2509 
2510  // We need to verify that the blocks follow the following layout:
2511  //
2512  //                     +-------------+
2513  //                     | entry block |
2514  //                     +-------------+
2515  //                            |
2516  //                            v
2517  //                     +-------------+
2518  //                     | loop header | <-----+
2519  //                     +-------------+       |
2520  //                                           |
2521  //                           ...             |
2522  //                          \  |  /          |
2523  //                             v             |
2524  //                    +---------------+      |
2525  //                    | loop continue | -----+
2526  //                    +---------------+
2527  //
2528  //                           ...
2529  //                          \  |  /
2530  //                             v
2531  //                     +-------------+
2532  //                     | merge block |
2533  //                     +-------------+
2534 
2535  auto &region = op->getRegion(0);
2536  // Allow empty region as a degenerated case, which can come from
2537  // optimizations.
2538  if (region.empty())
2539  return success();
2540 
2541  // The last block is the merge block.
2542  Block &merge = region.back();
2543  if (!isMergeBlock(merge))
2544  return loopOp.emitOpError(
2545  "last block must be the merge block with only one 'spv.mlir.merge' op");
2546 
2547  if (std::next(region.begin()) == region.end())
2548  return loopOp.emitOpError(
2549  "must have an entry block branching to the loop header block");
2550  // The first block is the entry block.
2551  Block &entry = region.front();
2552 
2553  if (std::next(region.begin(), 2) == region.end())
2554  return loopOp.emitOpError(
2555  "must have a loop header block branched from the entry block");
2556  // The second block is the loop header block.
2557  Block &header = *std::next(region.begin(), 1);
2558 
2559  if (!hasOneBranchOpTo(entry, header))
2560  return loopOp.emitOpError(
2561  "entry block must only have one 'spv.Branch' op to the second block");
2562 
2563  if (std::next(region.begin(), 3) == region.end())
2564  return loopOp.emitOpError(
2565  "requires a loop continue block branching to the loop header block");
2566  // The second to last block is the loop continue block. The check above
  // guarantees at least four blocks, so it is distinct from the header.
2567  Block &cont = *std::prev(region.end(), 2);
2568 
2569  // Make sure that we have a branch from the loop continue block to the loop
2570  // header block.
2571  if (llvm::none_of(
2572  llvm::seq<unsigned>(0, cont.getNumSuccessors()),
2573  [&](unsigned index) { return cont.getSuccessor(index) == &header; }))
2574  return loopOp.emitOpError("second to last block must be the loop continue "
2575  "block that branches to the loop header block");
2576 
2577  // Make sure that no other blocks (except the entry and loop continue block)
2578  // branch to the loop header block. The scanned range starts after the
  // header and stops before the continue block, so the header's own
  // successors (a header-to-header edge) are not checked here.
2579  for (auto &block : llvm::make_range(std::next(region.begin(), 2),
2580  std::prev(region.end(), 2))) {
2581  for (auto i : llvm::seq<unsigned>(0, block.getNumSuccessors())) {
2582  if (block.getSuccessor(i) == &header) {
2583  return loopOp.emitOpError("can only have the entry and loop continue "
2584  "block branching to the loop header block");
2585  }
2586  }
2587  }
2588 
2589  return success();
2590 }
2591 
2592 Block *spirv::LoopOp::getEntryBlock() {
2593  assert(!body().empty() && "op region should not be empty!");
2594  return &body().front();
2595 }
2596 
2597 Block *spirv::LoopOp::getHeaderBlock() {
2598  assert(!body().empty() && "op region should not be empty!");
2599  // The second block is the loop header block.
2600  return &*std::next(body().begin());
2601 }
2602 
2603 Block *spirv::LoopOp::getContinueBlock() {
2604  assert(!body().empty() && "op region should not be empty!");
2605  // The second to last block is the loop continue block.
2606  return &*std::prev(body().end(), 2);
2607 }
2608 
2609 Block *spirv::LoopOp::getMergeBlock() {
2610  assert(!body().empty() && "op region should not be empty!");
2611  // The last block is the loop merge block.
2612  return &body().back();
2613 }
2614 
2615 void spirv::LoopOp::addEntryAndMergeBlock() {
2616  assert(body().empty() && "entry and merge block already exist");
2617  body().push_back(new Block());
2618  auto *mergeBlock = new Block();
2619  body().push_back(mergeBlock);
2620  OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
2621 
2622  // Add a spv.mlir.merge op into the merge block.
2623  builder.create<spirv::MergeOp>(getLoc());
2624 }
2625 
2626 //===----------------------------------------------------------------------===//
2627 // spv.mlir.merge
2628 //===----------------------------------------------------------------------===//
2629 
2630 static LogicalResult verify(spirv::MergeOp mergeOp) {
2631  auto *parentOp = mergeOp->getParentOp();
2632  if (!parentOp || !isa<spirv::SelectionOp, spirv::LoopOp>(parentOp))
2633  return mergeOp.emitOpError(
2634  "expected parent op to be 'spv.mlir.selection' or 'spv.mlir.loop'");
2635 
2636  Block &parentLastBlock = mergeOp->getParentRegion()->back();
2637  if (mergeOp.getOperation() != parentLastBlock.getTerminator())
2638  return mergeOp.emitOpError("can only be used in the last block of "
2639  "'spv.mlir.selection' or 'spv.mlir.loop'");
2640  return success();
2641 }
2642 
2643 //===----------------------------------------------------------------------===//
2644 // spv.module
2645 //===----------------------------------------------------------------------===//
2646 
/// Builds an spv.module whose body region holds one empty block. When `name`
/// is provided, a string attribute built from it is added as well.
void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
                            Optional<StringRef> name) {
  // Keep the caller's insertion point; createBlock() repositions the builder.
  OpBuilder::InsertionGuard guard(builder);
  builder.createBlock(state.addRegion());
  if (name) {
    // NOTE(review): the head of this statement is not visible in this
    // listing; presumably it attaches the symbol name to `state`.
                            builder.getStringAttr(*name));
  }
}
2656 
/// Builds an spv.module with the given addressing and memory models, an
/// optional version/capability/extension triple, an optional name, and a body
/// region holding one empty block.
void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
                            spirv::AddressingModel addressingModel,
                            spirv::MemoryModel memoryModel,
                            Optional<VerCapExtAttr> vceTriple,
                            Optional<StringRef> name) {
  // Both models are stored as i32 attributes holding the enum's value.
  state.addAttribute(
      "addressing_model",
      builder.getI32IntegerAttr(static_cast<int32_t>(addressingModel)));
  state.addAttribute("memory_model", builder.getI32IntegerAttr(
                                         static_cast<int32_t>(memoryModel)));
  // Keep the caller's insertion point; createBlock() repositions the builder.
  OpBuilder::InsertionGuard guard(builder);
  builder.createBlock(state.addRegion());
  if (vceTriple)
    state.addAttribute(getVCETripleAttrName(), *vceTriple);
  if (name)
    // NOTE(review): the head of this statement is not visible in this
    // listing; presumably it attaches the symbol name to `state`.
                            builder.getStringAttr(*name));
}
2675 
  // Parses: [@name] <addressing model> <memory model> [requires <vce triple>]
  //         [attributes {...}] <region>
  // NOTE(review): the signature line of this parser is not visible in this
  // listing.
  Region *body = state.addRegion();

  // If the name is present, parse it.
  StringAttr nameAttr;
  // NOTE(review): the head of the following call is not visible in this
  // listing.
                               state.attributes);

  // Parse attributes
  // The addressing model and memory model keywords are mandatory.
  spirv::AddressingModel addrModel;
  spirv::MemoryModel memoryModel;
  if (parseEnumKeywordAttr(addrModel, parser, state) ||
      parseEnumKeywordAttr(memoryModel, parser, state))
    return failure();

  // Optional `requires` clause carrying the version/capability/extension
  // triple.
  if (succeeded(parser.parseOptionalKeyword("requires"))) {
    spirv::VerCapExtAttr vceTriple;
    if (parser.parseAttribute(vceTriple,
                              spirv::ModuleOp::getVCETripleAttrName(),
                              state.attributes))
      return failure();
  }

  if (parser.parseOptionalAttrDictWithKeyword(state.attributes))
    return failure();

  if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();

  // Make sure we have at least one block.
  if (body->empty())
    body->push_back(new Block());

  return success();
}
2711 
/// Prints an spv.module as: optional symbol name, addressing model, memory
/// model, optional `requires` triple, remaining attributes, then the body
/// region.
static void print(spirv::ModuleOp moduleOp, OpAsmPrinter &printer) {
  if (Optional<StringRef> name = moduleOp.getName()) {
    printer << ' ';
    printer.printSymbolName(*name);
  }

  // Attributes already spelled out by the custom syntax are collected here so
  // that they are not repeated in the trailing attribute dictionary.
  SmallVector<StringRef, 2> elidedAttrs;

  printer << " " << spirv::stringifyAddressingModel(moduleOp.addressing_model())
          << " " << spirv::stringifyMemoryModel(moduleOp.memory_model());
  auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
  auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
  // NOTE(review): the remainder of this initializer list is not visible in
  // this listing.
  elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,

  if (Optional<spirv::VerCapExtAttr> triple = moduleOp.vce_triple()) {
    printer << " requires " << *triple;
    elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
  }

  printer.printOptionalAttrDictWithKeyword(moduleOp->getAttrs(), elidedAttrs);
  printer << ' ';
  printer.printRegion(moduleOp.getRegion());
}
2736 
/// Verifies an spv.module:
///  - every op directly in the body belongs to the module's dialect;
///  - each spv.EntryPoint names an existing spv.func, lists only
///    spv.GlobalVariable symbols in its interface, and does not repeat a
///    (function, execution model) pair;
///  - spv.func ops are not external and contain only ops of that dialect.
static LogicalResult verify(spirv::ModuleOp moduleOp) {
  auto &op = *moduleOp.getOperation();
  auto *dialect = op.getDialect();
  // NOTE(review): the head of the following declaration is not visible in
  // this listing; from its uses below it maps (spv.func, execution model)
  // pairs to the spv.EntryPoint that declared them.
      entryPoints;
  SymbolTable table(moduleOp);

  for (auto &op : *moduleOp.getBody()) {
    if (op.getDialect() != dialect)
      return op.emitError("'spv.module' can only contain spv.* ops");

    // For EntryPoint op, check that the function and execution model is not
    // duplicated in EntryPointOps. Also verify that the interface specified
    // comes from globalVariables here to make this check cheaper.
    if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
      auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.fn());
      if (!funcOp) {
        return entryPointOp.emitError("function '")
               << entryPointOp.fn() << "' not found in 'spv.module'";
      }
      if (auto interface = entryPointOp.interface()) {
        for (Attribute varRef : interface) {
          auto varSymRef = varRef.dyn_cast<FlatSymbolRefAttr>();
          if (!varSymRef) {
            // NOTE(review): this diagnostic opens a quote before `varRef`
            // but never closes it.
            return entryPointOp.emitError(
                       "expected symbol reference for interface "
                       "specification instead of '")
                   << varRef;
          }
          auto variableOp =
              table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
          if (!variableOp) {
            // NOTE(review): "instead of'" lacks a space before the quote.
            return entryPointOp.emitError("expected spv.GlobalVariable "
                                          "symbol reference instead of'")
                   << varSymRef << "'";
          }
        }
      }

      // Reject a second entry point for the same function/execution model.
      auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
          funcOp, entryPointOp.execution_model());
      auto entryPtIt = entryPoints.find(key);
      if (entryPtIt != entryPoints.end()) {
        return entryPointOp.emitError("duplicate of a previous EntryPointOp");
      }
      entryPoints[key] = entryPointOp;
    } else if (auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
      if (funcOp.isExternal())
        return op.emitError("'spv.module' cannot contain external functions");

      // TODO: move this check to spv.func.
      // Note: the inner `op` below shadows the outer loop variable.
      for (auto &block : funcOp)
        for (auto &op : block) {
          if (op.getDialect() != dialect)
            return op.emitError(
                "functions in 'spv.module' can only contain spv.* ops");
        }
    }
  }

  return success();
}
2799 
2800 //===----------------------------------------------------------------------===//
2801 // spv.mlir.referenceof
2802 //===----------------------------------------------------------------------===//
2803 
2804 static LogicalResult verify(spirv::ReferenceOfOp referenceOfOp) {
2805  auto *specConstSym = SymbolTable::lookupNearestSymbolFrom(
2806  referenceOfOp->getParentOp(), referenceOfOp.spec_constAttr());
2807  Type constType;
2808 
2809  auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
2810  if (specConstOp)
2811  constType = specConstOp.default_value().getType();
2812 
2813  auto specConstCompositeOp =
2814  dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
2815  if (specConstCompositeOp)
2816  constType = specConstCompositeOp.type();
2817 
2818  if (!specConstOp && !specConstCompositeOp)
2819  return referenceOfOp.emitOpError(
2820  "expected spv.SpecConstant or spv.SpecConstantComposite symbol");
2821 
2822  if (referenceOfOp.reference().getType() != constType)
2823  return referenceOfOp.emitOpError("result type mismatch with the referenced "
2824  "specialization constant's type");
2825 
2826  return success();
2827 }
2828 
2829 //===----------------------------------------------------------------------===//
2830 // spv.Return
2831 //===----------------------------------------------------------------------===//
2832 
2833 static LogicalResult verify(spirv::ReturnOp returnOp) {
2834  // Verification is performed in spv.func op.
2835  return success();
2836 }
2837 
2838 //===----------------------------------------------------------------------===//
2839 // spv.ReturnValue
2840 //===----------------------------------------------------------------------===//
2841 
2842 static LogicalResult verify(spirv::ReturnValueOp retValOp) {
2843  // Verification is performed in spv.func op.
2844  return success();
2845 }
2846 
2847 //===----------------------------------------------------------------------===//
2848 // spv.Select
2849 //===----------------------------------------------------------------------===//
2850 
2851 static LogicalResult verify(spirv::SelectOp op) {
2852  if (auto conditionTy = op.condition().getType().dyn_cast<VectorType>()) {
2853  auto resultVectorTy = op.result().getType().dyn_cast<VectorType>();
2854  if (!resultVectorTy) {
2855  return op.emitOpError("result expected to be of vector type when "
2856  "condition is of vector type");
2857  }
2858  if (resultVectorTy.getNumElements() != conditionTy.getNumElements()) {
2859  return op.emitOpError("result should have the same number of elements as "
2860  "the condition when condition is of vector type");
2861  }
2862  }
2863  return success();
2864 }
2865 
2866 //===----------------------------------------------------------------------===//
2867 // spv.mlir.selection
2868 //===----------------------------------------------------------------------===//
2869 
                                    OperationState &state) {
  // Parses the selection control attribute (the printer spells a non-default
  // value as `control(...)`), then the single body region. The region is
  // parsed without entry-block arguments.
  // NOTE(review): the head of this function's signature is not visible in
  // this listing.
  if (parseControlAttribute<spirv::SelectionControl>(parser, state))
    return failure();
  return parser.parseRegion(*state.addRegion(), /*arguments=*/{},
                            /*argTypes=*/{});
}
2877 
2878 static void print(spirv::SelectionOp selectionOp, OpAsmPrinter &printer) {
2879  auto *op = selectionOp.getOperation();
2880  auto control = selectionOp.selection_control();
2881  if (control != spirv::SelectionControl::None)
2882  printer << " control(" << spirv::stringifySelectionControl(control) << ")";
2883  printer << ' ';
2884  printer.printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false,
2885  /*printBlockTerminators=*/true);
2886 }
2887 
2888 static LogicalResult verify(spirv::SelectionOp selectionOp) {
2889  auto *op = selectionOp.getOperation();
2890 
2891  // We need to verify that the blocks follow the following layout:
2892  //
2893  // +--------------+
2894  // | header block |
2895  // +--------------+
2896  // / | \
2897  // ...
2898  //
2899  //
2900  // +---------+ +---------+ +---------+
2901  // | case #0 | | case #1 | | case #2 | ...
2902  // +---------+ +---------+ +---------+
2903  //
2904  //
2905  // ...
2906  // \ | /
2907  // v
2908  // +-------------+
2909  // | merge block |
2910  // +-------------+
2911 
2912  auto &region = op->getRegion(0);
2913  // Allow empty region as a degenerated case, which can come from
2914  // optimizations.
2915  if (region.empty())
2916  return success();
2917 
2918  // The last block is the merge block.
2919  if (!isMergeBlock(region.back()))
2920  return selectionOp.emitOpError(
2921  "last block must be the merge block with only one 'spv.mlir.merge' op");
2922 
2923  if (std::next(region.begin()) == region.end())
2924  return selectionOp.emitOpError("must have a selection header block");
2925 
2926  return success();
2927 }
2928 
2929 Block *spirv::SelectionOp::getHeaderBlock() {
2930  assert(!body().empty() && "op region should not be empty!");
2931  // The first block is the loop header block.
2932  return &body().front();
2933 }
2934 
2935 Block *spirv::SelectionOp::getMergeBlock() {
2936  assert(!body().empty() && "op region should not be empty!");
2937  // The last block is the loop merge block.
2938  return &body().back();
2939 }
2940 
2941 void spirv::SelectionOp::addMergeBlock() {
2942  assert(body().empty() && "entry and merge block already exist");
2943  auto *mergeBlock = new Block();
2944  body().push_back(mergeBlock);
2945  OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
2946 
2947  // Add a spv.mlir.merge op into the merge block.
2948  builder.create<spirv::MergeOp>(getLoc());
2949 }
2950 
2951 spirv::SelectionOp spirv::SelectionOp::createIfThen(
2952  Location loc, Value condition,
2953  function_ref<void(OpBuilder &builder)> thenBody, OpBuilder &builder) {
2954  auto selectionOp =
2955  builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
2956 
2957  selectionOp.addMergeBlock();
2958  Block *mergeBlock = selectionOp.getMergeBlock();
2959  Block *thenBlock = nullptr;
2960 
2961  // Build the "then" block.
2962  {
2963  OpBuilder::InsertionGuard guard(builder);
2964  thenBlock = builder.createBlock(mergeBlock);
2965  thenBody(builder);
2966  builder.create<spirv::BranchOp>(loc, mergeBlock);
2967  }
2968 
2969  // Build the header block.
2970  {
2971  OpBuilder::InsertionGuard guard(builder);
2972  builder.createBlock(thenBlock);
2973  builder.create<spirv::BranchConditionalOp>(
2974  loc, condition, thenBlock,
2975  /*trueArguments=*/ArrayRef<Value>(), mergeBlock,
2976  /*falseArguments=*/ArrayRef<Value>());
2977  }
2978 
2979  return selectionOp;
2980 }
2981 
2982 //===----------------------------------------------------------------------===//
2983 // spv.SpecConstant
2984 //===----------------------------------------------------------------------===//
2985 
                                       OperationState &state) {
  // Parses: @sym_name [<spec id clause>] = <default value attribute>
  // NOTE(review): the head of this function's signature is not visible in
  // this listing.
  StringAttr nameAttr;
  Attribute valueAttr;

  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                             state.attributes))
    return failure();

  // Parse optional spec_id.
  // NOTE(review): the opening `if` of this optional clause is not visible in
  // this listing; the printer emits the clause as `<kSpecIdAttrName>(<int>)`.
    IntegerAttr specIdAttr;
    if (parser.parseLParen() ||
        parser.parseAttribute(specIdAttr, kSpecIdAttrName, state.attributes) ||
        parser.parseRParen())
      return failure();
  }

  // The default value is stored under kDefaultValueAttrName.
  if (parser.parseEqual() ||
      parser.parseAttribute(valueAttr, kDefaultValueAttrName, state.attributes))
    return failure();

  return success();
}
3010 
3011 static void print(spirv::SpecConstantOp constOp, OpAsmPrinter &printer) {
3012  printer << ' ';
3013  printer.printSymbolName(constOp.sym_name());
3014  if (auto specID = constOp->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
3015  printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')';
3016  printer << " = " << constOp.default_value();
3017 }
3018 
3019 static LogicalResult verify(spirv::SpecConstantOp constOp) {
3020  if (auto specID = constOp->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
3021  if (specID.getValue().isNegative())
3022  return constOp.emitOpError("SpecId cannot be negative");
3023 
3024  auto value = constOp.default_value();
3025  if (value.isa<IntegerAttr, FloatAttr>()) {
3026  // Make sure bitwidth is allowed.
3027  if (!value.getType().isa<spirv::SPIRVType>())
3028  return constOp.emitOpError("default value bitwidth disallowed");
3029  return success();
3030  }
3031  return constOp.emitOpError(
3032  "default value can only be a bool, integer, or float scalar");
3033 }
3034 
3035 //===----------------------------------------------------------------------===//
3036 // spv.StoreOp
3037 //===----------------------------------------------------------------------===//
3038 
  // Parses: "<storage class>" %ptr, %value [memory access] : <value type>
  // NOTE(review): the signature and the declaration of `operandInfo` are not
  // visible in this listing.
  // Parse the storage class specification
  spirv::StorageClass storageClass;
  auto loc = parser.getCurrentLocation();
  Type elementType;
  if (parseEnumStrAttr(storageClass, parser) ||
      parser.parseOperandList(operandInfo, 2) ||
      parseMemoryAccessAttributes(parser, state) || parser.parseColon() ||
      parser.parseType(elementType)) {
    return failure();
  }

  // The pointer operand's type is rebuilt from the stored value type and the
  // parsed storage class.
  auto ptrType = spirv::PointerType::get(elementType, storageClass);
  if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
                             state.operands)) {
    return failure();
  }
  return success();
}
3059 
3060 static void print(spirv::StoreOp storeOp, OpAsmPrinter &printer) {
3061  auto *op = storeOp.getOperation();
3062  SmallVector<StringRef, 4> elidedAttrs;
3063  StringRef sc = stringifyStorageClass(
3064  storeOp.ptr().getType().cast<spirv::PointerType>().getStorageClass());
3065  printer << " \"" << sc << "\" " << storeOp.ptr() << ", " << storeOp.value();
3066 
3067  printMemoryAccessAttribute(storeOp, printer, elidedAttrs);
3068 
3069  printer << " : " << storeOp.value().getType();
3070  printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
3071 }
3072 
3073 static LogicalResult verify(spirv::StoreOp storeOp) {
3074  // SPIR-V spec : "Pointer is the pointer to store through. Its type must be an
3075  // OpTypePointer whose Type operand is the same as the type of Object."
3076  if (failed(verifyLoadStorePtrAndValTypes(storeOp, storeOp.ptr(),
3077  storeOp.value()))) {
3078  return failure();
3079  }
3080  return verifyMemoryAccessAttribute(storeOp);
3081 }
3082 
3083 //===----------------------------------------------------------------------===//
3084 // spv.Unreachable
3085 //===----------------------------------------------------------------------===//
3086 
3087 static LogicalResult verify(spirv::UnreachableOp unreachableOp) {
3088  auto *op = unreachableOp.getOperation();
3089  auto *block = op->getBlock();
3090  // Fast track: if this is in entry block, its invalid. Otherwise, if no
3091  // predecessors, it's valid.
3092  if (block->isEntryBlock())
3093  return unreachableOp.emitOpError("cannot be used in reachable block");
3094  if (block->hasNoPredecessors())
3095  return success();
3096 
3097  // TODO: further verification needs to analyze reachability from
3098  // the entry block.
3099 
3100  return success();
3101 }
3102 
3103 //===----------------------------------------------------------------------===//
3104 // spv.Variable
3105 //===----------------------------------------------------------------------===//
3106 
  // Parses: [init(%value)] <decorations> : !spv.ptr<...>
  // NOTE(review): the signature and the declaration of `initInfo` are not
  // visible in this listing; from its uses it is an optional operand.
  // Parse optional initializer
  if (succeeded(parser.parseOptionalKeyword("init"))) {
    initInfo = OpAsmParser::OperandType();
    if (parser.parseLParen() || parser.parseOperand(*initInfo) ||
        parser.parseRParen())
      return failure();
  }

  if (parseVariableDecorations(parser, state)) {
    return failure();
  }

  // Parse result pointer type
  Type type;
  if (parser.parseColon())
    return failure();
  // Location of the type, used for the diagnostic below.
  auto loc = parser.getCurrentLocation();
  if (parser.parseType(type))
    return failure();

  auto ptrType = type.dyn_cast<spirv::PointerType>();
  if (!ptrType)
    return parser.emitError(loc, "expected spv.ptr type");
  state.addTypes(ptrType);

  // Resolve the initializer operand
  // The initializer is resolved against the result's pointee type.
  if (initInfo) {
    if (parser.resolveOperand(*initInfo, ptrType.getPointeeType(),
                              state.operands))
      return failure();
  }

  // The storage class attribute is not spelled in the syntax; it is derived
  // from the result pointer type.
  auto attr = parser.getBuilder().getI32IntegerAttr(
      llvm::bit_cast<int32_t>(ptrType.getStorageClass()));
  state.addAttribute(spirv::attributeName<spirv::StorageClass>(), attr);

  return success();
}
3147 
3148 static void print(spirv::VariableOp varOp, OpAsmPrinter &printer) {
3149  SmallVector<StringRef, 4> elidedAttrs{
3150  spirv::attributeName<spirv::StorageClass>()};
3151  // Print optional initializer
3152  if (varOp.getNumOperands() != 0)
3153  printer << " init(" << varOp.initializer() << ")";
3154 
3155  printVariableDecorations(varOp, printer, elidedAttrs);
3156  printer << " : " << varOp.getType();
3157 }
3158 
3159 static LogicalResult verify(spirv::VariableOp varOp) {
3160  // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
3161  // object. It cannot be Generic. It must be the same as the Storage Class
3162  // operand of the Result Type."
3163  if (varOp.storage_class() != spirv::StorageClass::Function) {
3164  return varOp.emitOpError(
3165  "can only be used to model function-level variables. Use "
3166  "spv.GlobalVariable for module-level variables.");
3167  }
3168 
3169  auto pointerType = varOp.pointer().getType().cast<spirv::PointerType>();
3170  if (varOp.storage_class() != pointerType.getStorageClass())
3171  return varOp.emitOpError(
3172  "storage class must match result pointer's storage class");
3173 
3174  if (varOp.getNumOperands() != 0) {
3175  // SPIR-V spec: "Initializer must be an <id> from a constant instruction or
3176  // a global (module scope) OpVariable instruction".
3177  auto *initOp = varOp.getOperand(0).getDefiningOp();
3178  if (!initOp || !isa<spirv::ConstantOp, // for normal constant
3179  spirv::ReferenceOfOp, // for spec constant
3180  spirv::AddressOfOp>(initOp))
3181  return varOp.emitOpError("initializer must be the result of a "
3182  "constant or spv.GlobalVariable op");
3183  }
3184 
3185  // TODO: generate these strings using ODS.
3186  auto *op = varOp.getOperation();
3187  auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
3188  stringifyDecoration(spirv::Decoration::DescriptorSet));
3189  auto bindingName = llvm::convertToSnakeFromCamelCase(
3190  stringifyDecoration(spirv::Decoration::Binding));
3191  auto builtInName = llvm::convertToSnakeFromCamelCase(
3192  stringifyDecoration(spirv::Decoration::BuiltIn));
3193 
3194  for (const auto &attr : {descriptorSetName, bindingName, builtInName}) {
3195  if (op->getAttr(attr))
3196  return varOp.emitOpError("cannot have '")
3197  << attr << "' attribute (only allowed in spv.GlobalVariable)";
3198  }
3199 
3200  return success();
3201 }
3202 
3203 //===----------------------------------------------------------------------===//
3204 // spv.VectorShuffle
3205 //===----------------------------------------------------------------------===//
3206 
3207 static LogicalResult verify(spirv::VectorShuffleOp shuffleOp) {
3208  VectorType resultType = shuffleOp.getType().cast<VectorType>();
3209 
3210  size_t numResultElements = resultType.getNumElements();
3211  if (numResultElements != shuffleOp.components().size())
3212  return shuffleOp.emitOpError("result type element count (")
3213  << numResultElements
3214  << ") mismatch with the number of component selectors ("
3215  << shuffleOp.components().size() << ")";
3216 
3217  size_t totalSrcElements =
3218  shuffleOp.vector1().getType().cast<VectorType>().getNumElements() +
3219  shuffleOp.vector2().getType().cast<VectorType>().getNumElements();
3220 
3221  for (const auto &selector :
3222  shuffleOp.components().getAsValueRange<IntegerAttr>()) {
3223  uint32_t index = selector.getZExtValue();
3224  if (index >= totalSrcElements &&
3225  index != std::numeric_limits<uint32_t>().max())
3226  return shuffleOp.emitOpError("component selector ")
3227  << index << " out of range: expected to be in [0, "
3228  << totalSrcElements << ") or 0xffffffff";
3229  }
3230  return success();
3231 }
3232 
3233 //===----------------------------------------------------------------------===//
3234 // spv.CooperativeMatrixLoadNV
3235 //===----------------------------------------------------------------------===//
3236 
                                                  OperationState &state) {
  // Parses: %ptr, %stride, %columnmajor [memory access]
  //         : <pointer type> as <result type>
  // NOTE(review): the head of the signature and the declaration of
  // `operandInfo` are not visible in this listing.
  // The stride operand is an i32 and the column-major flag an i1.
  Type strideType = parser.getBuilder().getIntegerType(32);
  Type columnMajorType = parser.getBuilder().getIntegerType(1);
  Type ptrType;
  Type elementType;
  if (parser.parseOperandList(operandInfo, 3) ||
      parseMemoryAccessAttributes(parser, state) || parser.parseColon() ||
      parser.parseType(ptrType) || parser.parseKeywordType("as", elementType)) {
    return failure();
  }
  if (parser.resolveOperands(operandInfo,
                             {ptrType, strideType, columnMajorType},
                             parser.getNameLoc(), state.operands)) {
    return failure();
  }

  // The type after `as` is the op's result type.
  state.addTypes(elementType);
  return success();
}
3258 
3259 static void print(spirv::CooperativeMatrixLoadNVOp m, OpAsmPrinter &printer) {
3260  printer << " " << m.pointer() << ", " << m.stride() << ", "
3261  << m.columnmajor();
3262  // Print optional memory access attribute.
3263  if (auto memAccess = m.memory_access())
3264  printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
3265  printer << " : " << m.pointer().getType() << " as " << m.getType();
3266 }
3267 
                                                     Type coopMatrix) {
  // Checks that `pointer` points to a scalar or vector type and uses the
  // Workgroup, StorageBuffer, or PhysicalStorageBuffer storage class.
  // NOTE(review): the head of this function's signature is not visible in
  // this listing; `coopMatrix` is not used in the visible body.
  Type pointeeType = pointer.cast<spirv::PointerType>().getPointeeType();
  if (!pointeeType.isa<spirv::ScalarType>() && !pointeeType.isa<VectorType>())
    return op->emitError(
               "Pointer must point to a scalar or vector type but provided ")
           << pointeeType;
  spirv::StorageClass storage =
      pointer.cast<spirv::PointerType>().getStorageClass();
  // NOTE(review): the message below says "PhysicalStorageBufferEXT" while the
  // check compares against StorageClass::PhysicalStorageBuffer.
  if (storage != spirv::StorageClass::Workgroup &&
      storage != spirv::StorageClass::StorageBuffer &&
      storage != spirv::StorageClass::PhysicalStorageBuffer)
    return op->emitError(
               "Pointer storage class must be Workgroup, StorageBuffer or "
               "PhysicalStorageBufferEXT but provided ")
           << stringifyStorageClass(storage);
  return success();
}
3286 
3287 //===----------------------------------------------------------------------===//
3288 // spv.CooperativeMatrixStoreNV
3289 //===----------------------------------------------------------------------===//
3290 
                                                   OperationState &state) {
  // Parses: %ptr, %object, %stride, %columnmajor [memory access]
  //         : <pointer type>, <object type>
  // NOTE(review): the head of the signature and the declaration of
  // `operandInfo` are not visible in this listing.
  // The stride operand is an i32 and the column-major flag an i1.
  Type strideType = parser.getBuilder().getIntegerType(32);
  Type columnMajorType = parser.getBuilder().getIntegerType(1);
  Type ptrType;
  Type elementType;
  if (parser.parseOperandList(operandInfo, 4) ||
      parseMemoryAccessAttributes(parser, state) || parser.parseColon() ||
      parser.parseType(ptrType) || parser.parseComma() ||
      parser.parseType(elementType)) {
    return failure();
  }
  if (parser.resolveOperands(
          operandInfo, {ptrType, elementType, strideType, columnMajorType},
          parser.getNameLoc(), state.operands)) {
    return failure();
  }

  // The op has no results.
  return success();
}
3312 
3313 static void print(spirv::CooperativeMatrixStoreNVOp coopMatrix,
3314  OpAsmPrinter &printer) {
3315  printer << " " << coopMatrix.pointer() << ", " << coopMatrix.object() << ", "
3316  << coopMatrix.stride() << ", " << coopMatrix.columnmajor();
3317  // Print optional memory access attribute.
3318  if (auto memAccess = coopMatrix.memory_access())
3319  printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
3320  printer << " : " << coopMatrix.pointer().getType() << ", "
3321  << coopMatrix.getOperand(1).getType();
3322 }
3323 
3324 //===----------------------------------------------------------------------===//
3325 // spv.CooperativeMatrixMulAddNV
3326 //===----------------------------------------------------------------------===//
3327 
3328 static LogicalResult
3329 verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) {
3330  if (op.c().getType() != op.result().getType())
3331  return op.emitOpError("result and third operand must have the same type");
3332  auto typeA = op.a().getType().cast<spirv::CooperativeMatrixNVType>();
3333  auto typeB = op.b().getType().cast<spirv::CooperativeMatrixNVType>();
3334  auto typeC = op.c().getType().cast<spirv::CooperativeMatrixNVType>();
3335  auto typeR = op.result().getType().cast<spirv::CooperativeMatrixNVType>();
3336  if (typeA.getRows() != typeR.getRows() ||
3337  typeA.getColumns() != typeB.getRows() ||
3338  typeB.getColumns() != typeR.getColumns())
3339  return op.emitOpError("matrix size must match");
3340  if (typeR.getScope() != typeA.getScope() ||
3341  typeR.getScope() != typeB.getScope() ||
3342  typeR.getScope() != typeC.getScope())
3343  return op.emitOpError("matrix scope must match");
3344  if (typeA.getElementType() != typeB.getElementType() ||
3345  typeR.getElementType() != typeC.getElementType())
3346  return op.emitOpError("matrix element type must match");
3347  return success();
3348 }
3349 
3350 //===----------------------------------------------------------------------===//
3351 // spv.MatrixTimesScalar
3352 //===----------------------------------------------------------------------===//
3353 
3354 static LogicalResult verifyMatrixTimesScalar(spirv::MatrixTimesScalarOp op) {
3355  // We already checked that result and matrix are both of matrix type in the
3356  // auto-generated verify method.
3357 
3358  auto inputMatrix = op.matrix().getType().cast<spirv::MatrixType>();
3359  auto resultMatrix = op.result().getType().cast<spirv::MatrixType>();
3360 
3361  // Check that the scalar type is the same as the matrix element type.
3362  if (op.scalar().getType() != inputMatrix.getElementType())
3363  return op.emitError("input matrix components' type and scaling value must "
3364  "have the same type");
3365 
3366  // Note that the next three checks could be done using the AllTypesMatch
3367  // trait in the Op definition file but it generates a vague error message.
3368 
3369  // Check that the input and result matrices have the same columns' count
3370  if (inputMatrix.getNumColumns() != resultMatrix.getNumColumns())
3371  return op.emitError("input and result matrices must have the same "
3372  "number of columns");
3373 
3374  // Check that the input and result matrices' have the same rows count
3375  if (inputMatrix.getNumRows() != resultMatrix.getNumRows())
3376  return op.emitError("input and result matrices' columns must have "
3377  "the same size");
3378 
3379  // Check that the input and result matrices' have the same component type
3380  if (inputMatrix.getElementType() != resultMatrix.getElementType())
3381  return op.emitError("input and result matrices' columns must have "
3382  "the same component type");
3383 
3384  return success();
3385 }
3386 
3387 //===----------------------------------------------------------------------===//
3388 // spv.CopyMemory
3389 //===----------------------------------------------------------------------===//
3390 
3391 static void print(spirv::CopyMemoryOp copyMemory, OpAsmPrinter &printer) {
3392  auto *op = copyMemory.getOperation();
3393  printer << ' ';
3394 
3395  StringRef targetStorageClass =
3396  stringifyStorageClass(copyMemory.target()
3397  .getType()
3398  .cast<spirv::PointerType>()
3399  .getStorageClass());
3400  printer << " \"" << targetStorageClass << "\" " << copyMemory.target()
3401  << ", ";
3402 
3403  StringRef sourceStorageClass =
3404  stringifyStorageClass(copyMemory.source()
3405  .getType()
3406  .cast<spirv::PointerType>()
3407  .getStorageClass());
3408  printer << " \"" << sourceStorageClass << "\" " << copyMemory.source();
3409 
3410  SmallVector<StringRef, 4> elidedAttrs;
3411  printMemoryAccessAttribute(copyMemory, printer, elidedAttrs);
3412  printSourceMemoryAccessAttribute(copyMemory, printer, elidedAttrs,
3413  copyMemory.source_memory_access(),
3414  copyMemory.source_alignment());
3415 
3416  printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
3417 
3418  Type pointeeType =
3419  copyMemory.target().getType().cast<spirv::PointerType>().getPointeeType();
3420  printer << " : " << pointeeType;
3421 }
3422 
                                     OperationState &state) {
  // Parses:
  //   "<target sc>" %target, "<source sc>" %source [memory access]
  //   [, source memory access] : <pointee type> {attr-dict}
  // NOTE(review): the head of this function's signature is not visible in
  // this listing.
  spirv::StorageClass targetStorageClass;
  OpAsmParser::OperandType targetPtrInfo;

  spirv::StorageClass sourceStorageClass;
  OpAsmParser::OperandType sourcePtrInfo;

  Type elementType;

  if (parseEnumStrAttr(targetStorageClass, parser) ||
      parser.parseOperand(targetPtrInfo) || parser.parseComma() ||
      parseEnumStrAttr(sourceStorageClass, parser) ||
      parser.parseOperand(sourcePtrInfo) ||
      parseMemoryAccessAttributes(parser, state)) {
    return failure();
  }

  // parseOptionalComma() yields success (false) when a comma is present.
  if (!parser.parseOptionalComma()) {
    // Parse 2nd memory access attributes.
    if (parseSourceMemoryAccessAttributes(parser, state)) {
      return failure();
    }
  }

  if (parser.parseColon() || parser.parseType(elementType))
    return failure();

  // The optional attribute dictionary is read after the type.
  if (parser.parseOptionalAttrDict(state.attributes))
    return failure();

  // Both pointers share the single parsed pointee type.
  auto targetPtrType = spirv::PointerType::get(elementType, targetStorageClass);
  auto sourcePtrType = spirv::PointerType::get(elementType, sourceStorageClass);

  if (parser.resolveOperand(targetPtrInfo, targetPtrType, state.operands) ||
      parser.resolveOperand(sourcePtrInfo, sourcePtrType, state.operands)) {
    return failure();
  }

  return success();
}
3464 
3465 static LogicalResult verifyCopyMemory(spirv::CopyMemoryOp copyMemory) {
3466  Type targetType =
3467  copyMemory.target().getType().cast<spirv::PointerType>().getPointeeType();
3468 
3469  Type sourceType =
3470  copyMemory.source().getType().cast<spirv::PointerType>().getPointeeType();
3471 
3472  if (targetType != sourceType) {
3473  return copyMemory.emitOpError(
3474  "both operands must be pointers to the same type");
3475  }
3476 
3477  if (failed(verifyMemoryAccessAttribute(copyMemory))) {
3478  return failure();
3479  }
3480 
3481  // TODO - According to the spec:
3482  //
3483  // If two masks are present, the first applies to Target and cannot include
3484  // MakePointerVisible, and the second applies to Source and cannot include
3485  // MakePointerAvailable.
3486  //
3487  // Add such verification here.
3488 
3489  return verifySourceMemoryAccessAttribute(copyMemory);
3490 }
3491 
3492 //===----------------------------------------------------------------------===//
3493 // spv.Transpose
3494 //===----------------------------------------------------------------------===//
3495 
3496 static LogicalResult verifyTranspose(spirv::TransposeOp op) {
3497  auto inputMatrix = op.matrix().getType().cast<spirv::MatrixType>();
3498  auto resultMatrix = op.result().getType().cast<spirv::MatrixType>();
3499 
3500  // Verify that the input and output matrices have correct shapes.
3501  if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
3502  return op.emitError("input matrix rows count must be equal to "
3503  "output matrix columns count");
3504 
3505  if (inputMatrix.getNumColumns() != resultMatrix.getNumRows())
3506  return op.emitError("input matrix columns count must be equal to "
3507  "output matrix rows count");
3508 
3509  // Verify that the input and output matrices have the same component type
3510  if (inputMatrix.getElementType() != resultMatrix.getElementType())
3511  return op.emitError("input and output matrices must have the same "
3512  "component type");
3513 
3514  return success();
3515 }
3516 
3517 //===----------------------------------------------------------------------===//
3518 // spv.MatrixTimesMatrix
3519 //===----------------------------------------------------------------------===//
3520 
3521 static LogicalResult verifyMatrixTimesMatrix(spirv::MatrixTimesMatrixOp op) {
3522  auto leftMatrix = op.leftmatrix().getType().cast<spirv::MatrixType>();
3523  auto rightMatrix = op.rightmatrix().getType().cast<spirv::MatrixType>();
3524  auto resultMatrix = op.result().getType().cast<spirv::MatrixType>();
3525 
3526  // left matrix columns' count and right matrix rows' count must be equal
3527  if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
3528  return op.emitError("left matrix columns' count must be equal to "
3529  "the right matrix rows' count");
3530 
3531  // right and result matrices columns' count must be the same
3532  if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns())
3533  return op.emitError(
3534  "right and result matrices must have equal columns' count");
3535 
3536  // right and result matrices component type must be the same
3537  if (rightMatrix.getElementType() != resultMatrix.getElementType())
3538  return op.emitError("right and result matrices' component type must"
3539  " be the same");
3540 
3541  // left and result matrices component type must be the same
3542  if (leftMatrix.getElementType() != resultMatrix.getElementType())
3543  return op.emitError("left and result matrices' component type"
3544  " must be the same");
3545 
3546  // left and result matrices rows count must be the same
3547  if (leftMatrix.getNumRows() != resultMatrix.getNumRows())
3548  return op.emitError("left and result matrices must have equal rows'"
3549  " count");
3550 
3551  return success();
3552 }
3553 
3554 //===----------------------------------------------------------------------===//
3555 // spv.SpecConstantComposite
3556 //===----------------------------------------------------------------------===//
3557 
3559  OperationState &state) {
3560 
3561  StringAttr compositeName;
3562  if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
3563  state.attributes))
3564  return failure();
3565 
3566  if (parser.parseLParen())
3567  return failure();
3568 
3569  SmallVector<Attribute, 4> constituents;
3570 
3571  do {
3572  // The name of the constituent attribute isn't important
3573  const char *attrName = "spec_const";
3574  FlatSymbolRefAttr specConstRef;
3575  NamedAttrList attrs;
3576 
3577  if (parser.parseAttribute(specConstRef, Type(), attrName, attrs))
3578  return failure();
3579 
3580  constituents.push_back(specConstRef);
3581  } while (!parser.parseOptionalComma());
3582 
3583  if (parser.parseRParen())
3584  return failure();
3585 
3587  parser.getBuilder().getArrayAttr(constituents));
3588 
3589  Type type;
3590  if (parser.parseColonType(type))
3591  return failure();
3592 
3593  state.addAttribute(kTypeAttrName, TypeAttr::get(type));
3594 
3595  return success();
3596 }
3597 
3598 static void print(spirv::SpecConstantCompositeOp op, OpAsmPrinter &printer) {
3599  printer << " ";
3600  printer.printSymbolName(op.sym_name());
3601  printer << " (";
3602  auto constituents = op.constituents().getValue();
3603 
3604  if (!constituents.empty())
3605  llvm::interleaveComma(constituents, printer);
3606 
3607  printer << ") : " << op.type();
3608 }
3609 
3610 static LogicalResult verify(spirv::SpecConstantCompositeOp constOp) {
3611  auto cType = constOp.type().dyn_cast<spirv::CompositeType>();
3612  auto constituents = constOp.constituents().getValue();
3613 
3614  if (!cType)
3615  return constOp.emitError(
3616  "result type must be a composite type, but provided ")
3617  << constOp.type();
3618 
3619  if (cType.isa<spirv::CooperativeMatrixNVType>())
3620  return constOp.emitError("unsupported composite type ") << cType;
3621  if (constituents.size() != cType.getNumElements())
3622  return constOp.emitError("has incorrect number of operands: expected ")
3623  << cType.getNumElements() << ", but provided "
3624  << constituents.size();
3625 
3626  for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
3627  auto constituent = constituents[index].cast<FlatSymbolRefAttr>();
3628 
3629  auto constituentSpecConstOp =
3630  dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom(
3631  constOp->getParentOp(), constituent.getAttr()));
3632 
3633  if (constituentSpecConstOp.default_value().getType() !=
3634  cType.getElementType(index))
3635  return constOp.emitError("has incorrect types of operands: expected ")
3636  << cType.getElementType(index) << ", but provided "
3637  << constituentSpecConstOp.default_value().getType();
3638  }
3639 
3640  return success();
3641 }
3642 
3643 //===----------------------------------------------------------------------===//
3644 // spv.SpecConstantOperation
3645 //===----------------------------------------------------------------------===//
3646 
3648  OperationState &state) {
3649  Region *body = state.addRegion();
3650 
3651  if (parser.parseKeyword("wraps"))
3652  return failure();
3653 
3654  body->push_back(new Block);
3655  Block &block = body->back();
3656  Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
3657 
3658  if (!wrappedOp)
3659  return failure();
3660 
3661  OpBuilder builder(parser.getContext());
3662  builder.setInsertionPointToEnd(&block);
3663  builder.create<spirv::YieldOp>(wrappedOp->getLoc(), wrappedOp->getResult(0));
3664  state.location = wrappedOp->getLoc();
3665 
3666  state.addTypes(wrappedOp->getResult(0).getType());
3667 
3668  if (parser.parseOptionalAttrDict(state.attributes))
3669  return failure();
3670 
3671  return success();
3672 }
3673 
3674 static void print(spirv::SpecConstantOperationOp op, OpAsmPrinter &printer) {
3675  printer << " wraps ";
3676  printer.printGenericOp(&op.body().front().front());
3677 }
3678 
3679 static LogicalResult verify(spirv::SpecConstantOperationOp constOp) {
3680  Block &block = constOp.getRegion().getBlocks().front();
3681 
3682  if (block.getOperations().size() != 2)
3683  return constOp.emitOpError("expected exactly 2 nested ops");
3684 
3685  Operation &enclosedOp = block.getOperations().front();
3686 
3688  return constOp.emitOpError("invalid enclosed op");
3689 
3690  for (auto operand : enclosedOp.getOperands())
3691  if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
3692  spirv::SpecConstantOperationOp>(operand.getDefiningOp()))
3693  return constOp.emitOpError(
3694  "invalid operand, must be defined by a constant operation");
3695 
3696  return success();
3697 }
3698 
3699 //===----------------------------------------------------------------------===//
3700 // spv.GLSL.FrexpStruct
3701 //===----------------------------------------------------------------------===//
3702 static LogicalResult
3703 verifyGLSLFrexpStructOp(spirv::GLSLFrexpStructOp frexpStructOp) {
3704  spirv::StructType structTy =
3705  frexpStructOp.result().getType().dyn_cast<spirv::StructType>();
3706 
3707  if (structTy.getNumElements() != 2)
3708  return frexpStructOp.emitError("result type must be a struct type "
3709  "with two memebers");
3710 
3711  Type significandTy = structTy.getElementType(0);
3712  Type exponentTy = structTy.getElementType(1);
3713  VectorType exponentVecTy = exponentTy.dyn_cast<VectorType>();
3714  IntegerType exponentIntTy = exponentTy.dyn_cast<IntegerType>();
3715 
3716  Type operandTy = frexpStructOp.operand().getType();
3717  VectorType operandVecTy = operandTy.dyn_cast<VectorType>();
3718  FloatType operandFTy = operandTy.dyn_cast<FloatType>();
3719 
3720  if (significandTy != operandTy)
3721  return frexpStructOp.emitError("member zero of the resulting struct type "
3722  "must be the same type as the operand");
3723 
3724  if (exponentVecTy) {
3725  IntegerType componentIntTy =
3726  exponentVecTy.getElementType().dyn_cast<IntegerType>();
3727  if (!(componentIntTy && componentIntTy.getWidth() == 32))
3728  return frexpStructOp.emitError(
3729  "member one of the resulting struct type must"
3730  "be a scalar or vector of 32 bit integer type");
3731  } else if (!(exponentIntTy && exponentIntTy.getWidth() == 32)) {
3732  return frexpStructOp.emitError(
3733  "member one of the resulting struct type "
3734  "must be a scalar or vector of 32 bit integer type");
3735  }
3736 
3737  // Check that the two member types have the same number of components
3738  if (operandVecTy && exponentVecTy &&
3739  (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
3740  return success();
3741 
3742  if (operandFTy && exponentIntTy)
3743  return success();
3744 
3745  return frexpStructOp.emitError(
3746  "member one of the resulting struct type "
3747  "must have the same number of components as the operand type");
3748 }
3749 
3750 //===----------------------------------------------------------------------===//
3751 // spv.GLSL.Ldexp
3752 //===----------------------------------------------------------------------===//
3753 
3754 static LogicalResult verify(spirv::GLSLLdexpOp ldexpOp) {
3755  Type significandType = ldexpOp.x().getType();
3756  Type exponentType = ldexpOp.exp().getType();
3757 
3758  if (significandType.isa<FloatType>() != exponentType.isa<IntegerType>())
3759  return ldexpOp.emitOpError("operands must both be scalars or vectors");
3760 
3761  auto getNumElements = [](Type type) -> unsigned {
3762  if (auto vectorType = type.dyn_cast<VectorType>())
3763  return vectorType.getNumElements();
3764  return 1;
3765  };
3766 
3767  if (getNumElements(significandType) != getNumElements(exponentType))
3768  return ldexpOp.emitOpError(
3769  "operands must have the same number of elements");
3770 
3771  return success();
3772 }
3773 
3774 //===----------------------------------------------------------------------===//
3775 // spv.ImageDrefGather
3776 //===----------------------------------------------------------------------===//
3777 
3778 static LogicalResult verify(spirv::ImageDrefGatherOp imageDrefGatherOp) {
3779  VectorType resultType =
3780  imageDrefGatherOp.result().getType().cast<VectorType>();
3781  auto sampledImageType = imageDrefGatherOp.sampledimage()
3782  .getType()
3783  .cast<spirv::SampledImageType>();
3784  auto imageType = sampledImageType.getImageType().cast<spirv::ImageType>();
3785 
3786  if (resultType.getNumElements() != 4)
3787  return imageDrefGatherOp.emitOpError(
3788  "result type must be a vector of four components");
3789 
3790  Type elementType = resultType.getElementType();
3791  Type sampledElementType = imageType.getElementType();
3792  if (!sampledElementType.isa<NoneType>() && elementType != sampledElementType)
3793  return imageDrefGatherOp.emitOpError(
3794  "the component type of result must be the same as sampled type of the "
3795  "underlying image type");
3796 
3797  spirv::Dim imageDim = imageType.getDim();
3798  spirv::ImageSamplingInfo imageMS = imageType.getSamplingInfo();
3799 
3800  if (imageDim != spirv::Dim::Dim2D && imageDim != spirv::Dim::Cube &&
3801  imageDim != spirv::Dim::Rect)
3802  return imageDrefGatherOp.emitOpError(
3803  "the Dim operand of the underlying image type must be 2D, Cube, or "
3804  "Rect");
3805 
3806  if (imageMS != spirv::ImageSamplingInfo::SingleSampled)
3807  return imageDrefGatherOp.emitOpError(
3808  "the MS operand of the underlying image type must be 0");
3809 
3810  spirv::ImageOperandsAttr attr = imageDrefGatherOp.imageoperandsAttr();
3811  auto operandArguments = imageDrefGatherOp.operand_arguments();
3812 
3813  return verifyImageOperands(imageDrefGatherOp, attr, operandArguments);
3814 }
3815 
3816 //===----------------------------------------------------------------------===//
3817 // spv.ImageQuerySize
3818 //===----------------------------------------------------------------------===//
3819 
3820 static LogicalResult verify(spirv::ImageQuerySizeOp imageQuerySizeOp) {
3821  spirv::ImageType imageType =
3822  imageQuerySizeOp.image().getType().cast<spirv::ImageType>();
3823  Type resultType = imageQuerySizeOp.result().getType();
3824 
3825  spirv::Dim dim = imageType.getDim();
3826  spirv::ImageSamplingInfo samplingInfo = imageType.getSamplingInfo();
3827  spirv::ImageSamplerUseInfo samplerInfo = imageType.getSamplerUseInfo();
3828  switch (dim) {
3829  case spirv::Dim::Dim1D:
3830  case spirv::Dim::Dim2D:
3831  case spirv::Dim::Dim3D:
3832  case spirv::Dim::Cube:
3833  if (!(samplingInfo == spirv::ImageSamplingInfo::MultiSampled ||
3834  samplerInfo == spirv::ImageSamplerUseInfo::SamplerUnknown ||
3835  samplerInfo == spirv::ImageSamplerUseInfo::NoSampler))
3836  return imageQuerySizeOp.emitError(
3837  "if Dim is 1D, 2D, 3D, or Cube, "
3838  "it must also have either an MS of 1 or a Sampled of 0 or 2");
3839  break;
3840  case spirv::Dim::Buffer:
3841  case spirv::Dim::Rect:
3842  break;
3843  default:
3844  return imageQuerySizeOp.emitError("the Dim operand of the image type must "
3845  "be 1D, 2D, 3D, Buffer, Cube, or Rect");
3846  }
3847 
3848  unsigned componentNumber = 0;
3849  switch (dim) {
3850  case spirv::Dim::Dim1D:
3851  case spirv::Dim::Buffer:
3852  componentNumber = 1;
3853  break;
3854  case spirv::Dim::Dim2D:
3855  case spirv::Dim::Cube:
3856  case spirv::Dim::Rect:
3857  componentNumber = 2;
3858  break;
3859  case spirv::Dim::Dim3D:
3860  componentNumber = 3;
3861  break;
3862  default:
3863  break;
3864  }
3865 
3866  if (imageType.getArrayedInfo() == spirv::ImageArrayedInfo::Arrayed)
3867  componentNumber += 1;
3868 
3869  unsigned resultComponentNumber = 1;
3870  if (auto resultVectorType = resultType.dyn_cast<VectorType>())
3871  resultComponentNumber = resultVectorType.getNumElements();
3872 
3873  if (componentNumber != resultComponentNumber)
3874  return imageQuerySizeOp.emitError("expected the result to have ")
3875  << componentNumber << " component(s), but found "
3876  << resultComponentNumber << " component(s)";
3877 
3878  return success();
3879 }
3880 
3881 static ParseResult parsePtrAccessChainOpImpl(StringRef opName,
3882  OpAsmParser &parser,
3883  OperationState &state) {
3884  OpAsmParser::OperandType ptrInfo;
3886  Type type;
3887  auto loc = parser.getCurrentLocation();
3888  SmallVector<Type, 4> indicesTypes;
3889 
3890  if (parser.parseOperand(ptrInfo) ||
3891  parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
3892  parser.parseColonType(type) ||
3893  parser.resolveOperand(ptrInfo, type, state.operands))
3894  return failure();
3895 
3896  // Check that the provided indices list is not empty before parsing their
3897  // type list.
3898  if (indicesInfo.empty())
3899  return emitError(state.location) << opName << " expected element";
3900 
3901  if (parser.parseComma() || parser.parseTypeList(indicesTypes))
3902  return failure();
3903 
3904  // Check that the indices types list is not empty and that it has a one-to-one
3905  // mapping to the provided indices.
3906  if (indicesTypes.size() != indicesInfo.size())
3907  return emitError(state.location)
3908  << opName
3909  << " indices types' count must be equal to indices info count";
3910 
3911  if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands))
3912  return failure();
3913 
3914  auto resultType = getElementPtrType(
3915  type, llvm::makeArrayRef(state.operands).drop_front(2), state.location);
3916  if (!resultType)
3917  return failure();
3918 
3919  state.addTypes(resultType);
3920  return success();
3921 }
3922 
3923 template <typename Op>
3924 static auto concatElemAndIndices(Op op) {
3925  SmallVector<Value> ret(op.indices().size() + 1);
3926  ret[0] = op.element();
3927  llvm::copy(op.indices(), ret.begin() + 1);
3928  return ret;
3929 }
3930 
3931 //===----------------------------------------------------------------------===//
3932 // spv.InBoundsPtrAccessChainOp
3933 //===----------------------------------------------------------------------===//
3934 
3935 void spirv::InBoundsPtrAccessChainOp::build(OpBuilder &builder,
3936  OperationState &state,
3937  Value basePtr, Value element,
3938  ValueRange indices) {
3939  auto type = getElementPtrType(basePtr.getType(), indices, state.location);
3940  assert(type && "Unable to deduce return type based on basePtr and indices");
3941  build(builder, state, type, basePtr, element, indices);
3942 }
3943 
3945  OperationState &state) {
3947  spirv::InBoundsPtrAccessChainOp::getOperationName(), parser, state);
3948 }
3949 
3950 static void print(spirv::InBoundsPtrAccessChainOp op, OpAsmPrinter &printer) {
3951  printAccessChain(op, concatElemAndIndices(op), printer);
3952 }
3953 
3954 static LogicalResult verify(spirv::InBoundsPtrAccessChainOp accessChainOp) {
3955  return verifyAccessChain(accessChainOp, accessChainOp.indices());
3956 }
3957 
3958 //===----------------------------------------------------------------------===//
3959 // spv.PtrAccessChainOp
3960 //===----------------------------------------------------------------------===//
3961 
3962 void spirv::PtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
3963  Value basePtr, Value element,
3964  ValueRange indices) {
3965  auto type = getElementPtrType(basePtr.getType(), indices, state.location);
3966  assert(type && "Unable to deduce return type based on basePtr and indices");
3967  build(builder, state, type, basePtr, element, indices);
3968 }
3969 
3971  OperationState &state) {
3972  return parsePtrAccessChainOpImpl(spirv::PtrAccessChainOp::getOperationName(),
3973  parser, state);
3974 }
3975 
3976 static void print(spirv::PtrAccessChainOp op, OpAsmPrinter &printer) {
3977  printAccessChain(op, concatElemAndIndices(op), printer);
3978 }
3979 
3980 static LogicalResult verify(spirv::PtrAccessChainOp accessChainOp) {
3981  return verifyAccessChain(accessChainOp, accessChainOp.indices());
3982 }
3983 
3984 // TableGen'erated operation interfaces for querying versions, extensions, and
3985 // capabilities.
3986 #include "mlir/Dialect/SPIRV/IR/SPIRVAvailability.cpp.inc"
3987 
3988 // TableGen'erated operation definitions.
3989 #define GET_OP_CLASSES
3990 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"
3991 
3992 namespace mlir {
3993 namespace spirv {
3994 // TableGen'erated operation availability interface implementations.
3995 #include "mlir/Dialect/SPIRV/IR/SPIRVOpAvailabilityImpl.inc"
3996 } // namespace spirv
3997 } // namespace mlir
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
virtual ParseResult parseOperand(OperandType &result)=0
Parse a single operand.
This is the representation of an operand reference.
void printFunctionSignature(OpAsmPrinter &p, Operation *op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
static bool isMergeBlock(Block &block)
Returns true if the given block only contains one spv.mlir.merge op.
Definition: SPIRVOps.cpp:706
Include the generated interface declarations.
OpTy create(Location location, Args &&...args)
Create an operation of specific op type at the current insertion point.
Definition: Builders.h:430
static ParseResult parseImageOperands(OpAsmParser &parser, spirv::ImageOperandsAttr &attr)
Definition: SPIRVOps.cpp:319
static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:233
static void buildLogicalUnaryOp(OpBuilder &builder, OperationState &state, Value value)
Definition: SPIRVOps.cpp:979
Block * getSuccessor(unsigned i)
Definition: Block.cpp:240
This class contains a list of basic blocks and a link to the parent operation it is attached to...
Definition: Region.h:26
static void printShiftOp(Operation *op, OpAsmPrinter &printer)
Definition: SPIRVOps.cpp:950
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:114
iterator begin()
Definition: Block.h:134
ParseResult resolveOperands(ArrayRef< OperandType > operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:55
virtual ParseResult parseLParen()=0
Parse a ( token.
MLIRContext * getContext() const
Definition: Builders.h:54
U cast() const
Definition: Attributes.h:123
static ParseResult parsePtrAccessChainOpImpl(StringRef opName, OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:3881
static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp)
Definition: SPIRVOps.cpp:461
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
static void printLogicalOp(Operation *logicalOp, OpAsmPrinter &printer)
Definition: SPIRVOps.cpp:928
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
static LogicalResult verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics)
Definition: SPIRVOps.cpp:501
static ParseResult parseConstantOp(OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:1616
Operation & back()
Definition: Block.h:143
static ParseResult parseFuncOp(OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:1944
static std::string bindingName()
Returns the string name of the Binding decoration.
static void printUnaryOp(Operation *unaryOp, OpAsmPrinter &printer)
Definition: SPIRVOps.cpp:890
static ParseResult parsePtrAccessChainOp(OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:3970
operand_range getOperands()
Returns an iterator over the underlying Values.
Definition: Operation.h:247
static constexpr const char kMemoryAccessAttrName[]
Definition: SPIRVOps.cpp:37
static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op, Value ptr, Value val)
Definition: SPIRVOps.cpp:540
unsigned getNumElements() const
Return the number of elements of the type.
Definition: SPIRVTypes.cpp:122
Type getPointeeType() const
Definition: SPIRVTypes.cpp:395
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:301
static void print(spirv::AccessChainOp op, OpAsmPrinter &printer)
Definition: SPIRVOps.cpp:1103
Block represents an ordered list of Operations.
Definition: Block.h:29
A symbol reference with a reference path containing a single element.
static ParseResult parseEntryPointOp(OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:1841
static ParseResult parseInBoundsPtrAccessChainOp(OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:3944
static LogicalResult verifyGLSLFrexpStructOp(spirv::GLSLFrexpStructOp frexpStructOp)
Definition: SPIRVOps.cpp:3703
static ParseResult parseSelectionOp(OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:2870
static ParseResult parseSpecConstantCompositeOp(OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:3558
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:308
Value getOperand(unsigned idx)
Definition: Operation.h:219
static constexpr const char kSourceMemoryAccessAttrName[]
Definition: SPIRVOps.cpp:38
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:327
OpListType & getOperations()
Definition: Block.h:128
LogicalResult verify(Operation *op)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:353
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static LogicalResult verifyAtomicUpdateOp(Operation *op)
Definition: SPIRVOps.cpp:774
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments, to the list of operation attributes in result.
bool isa() const
Definition: Attributes.h:107
void push_back(Block *block)
Definition: Region.h:61
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
unsigned getNumElements() const
Definition: SPIRVTypes.cpp:60
static ParseResult parseUnaryOp(OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:879
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
unsigned getNumOperands()
Definition: Operation.h:215
static constexpr const char kDefaultValueAttrName[]
Definition: SPIRVOps.cpp:46
static ParseResult parseVariableDecorations(OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:552
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
Definition: SPIRVOps.cpp:639
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:87
NoneType getNoneType()
Definition: Builders.cpp:75
static Type getUnaryOpResultType(Builder &builder, Type operandType)
Result of a logical op must be a scalar or vector of boolean type.
Definition: SPIRVOps.cpp:896
static OpBuilder atBlockEnd(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still insid...
Definition: Builders.h:215
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
static ParseResult parseCompositeConstructOp(OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:1422
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
Definition: LogicalResult.h:68
static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr, Value val)
Definition: SPIRVOps.cpp:525
StringRef getTypeAttrName()
Return the name of the attribute used for function types.
Operation & front()
Definition: Block.h:144
static void printVariableDecorations(Operation *op, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs)
Definition: SPIRVOps.cpp:589
The OpAsmParser has methods for interacting with the asm parser: parsing things from it...
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
static std::string descriptorSetName()
Returns the string name of the DescriptorSet decoration.
static ParseResult parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next string attribute in parser as an enumerant of the given EnumClass.
Definition: SPIRVOps.cpp:123
static ParseResult parseAtomicCompareExchangeImpl(OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:1163
static bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock)
Returns true if the given srcBlock contains only one spv.Branch to the given dstBlock.
Definition: SPIRVOps.cpp:2498
static ParseResult parseSubgroupBlockReadINTELOp(OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:2316
static bool isNestedInFunctionOpInterface(Operation *op)
Returns true if the given op is a function-like op or nested in a function-like op without a module-l...
Definition: SPIRVOps.cpp:69
static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer, Type coopMatrix)
Definition: SPIRVOps.cpp:3268
static ParseResult parseSpecConstantOp(OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:2986
static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer)
Definition: SPIRVOps.cpp:1098
static StringRef stringifyTypeName()
static constexpr const char kControl[]
Definition: SPIRVOps.cpp:45
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseComma()=0
Parse a , token.
virtual llvm::SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseSuccessorAndUseList(Block *&dest, SmallVectorImpl< Value > &operands)=0
Parse a single operation successor and its operand list.
static constexpr const bool value
ParseResult parseSymbolName(StringAttr &result, StringRef attrName, NamedAttrList &attrs)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute named 'attrName'...
static void printAtomicUpdateOp(Operation *op, OpAsmPrinter &printer)
Definition: SPIRVOps.cpp:746
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
SmallVector< Value, 4 > operands
ParseResult parseFunctionSignature(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::OperandType > &argNames, SmallVectorImpl< Type > &argTypes, SmallVectorImpl< NamedAttrList > &argAttrs, SmallVectorImpl< Location > &argLocations, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< NamedAttrList > &resultAttrs)
Parses a function signature using parser.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
StringRef stringifyTypeName< FloatType >()
Definition: SPIRVOps.cpp:768
virtual ParseResult parseOperandList(SmallVectorImpl< OperandType > &result, int requiredOperandCount=-1, Delimiter delimiter=Delimiter::None)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter...
static constexpr const char kInterfaceAttrName[]
Definition: SPIRVOps.cpp:53
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, unsigned numInputs, unsigned numResults, ArrayRef< StringRef > elided={})
Prints the list of function prefixed with the "attributes" keyword.
static constexpr const char kSpecIdAttrName[]
Definition: SPIRVOps.cpp:56
static constexpr const char kTypeAttrName[]
Definition: SPIRVOps.cpp:57
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:193
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:148
virtual ParseResult parseColon()=0
Parse a : token.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
ImageArrayedInfo getArrayedInfo() const
Definition: SPIRVTypes.cpp:337
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
virtual ParseResult resolveOperand(const OperandType &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool empty()
Definition: Region.h:60
static ParseResult parseCompositeInsertOp(OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:1563
virtual ParseResult parseRegion(Region &region, ArrayRef< OperandType > arguments={}, ArrayRef< Type > argTypes={}, ArrayRef< Location > argLocations={}, bool enableNameShadowing=false)=0
Parses a region.
static unsigned getBitWidth(Type type)
Definition: SPIRVOps.cpp:617
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:618
void addOperands(ValueRange newOperands)
static ParseResult parseCompositeExtractOp(OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:1506
virtual llvm::SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual OptionalParseResult parseOptionalRegion(Region &region, ArrayRef< OperandType > arguments={}, ArrayRef< Type > argTypes={}, ArrayRef< Location > argLocations={}, bool enableNameShadowing=false)=0
Parses a region if present.
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:170
static void printAtomicCompareExchangeImpl(T atomOp, OpAsmPrinter &printer)
Definition: SPIRVOps.cpp:1156
U dyn_cast() const
Definition: Types.h:244
iterator end()
Definition: Block.h:135
static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp)
Definition: SPIRVOps.cpp:417
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
Attributes are known-constant values of operations.
Definition: Attributes.h:24
static constexpr const char kIndicesAttrName[]
Definition: SPIRVOps.cpp:51
static constexpr const char kClusterSize[]
Definition: SPIRVOps.cpp:44
ImageSamplerUseInfo getSamplerUseInfo() const
Definition: SPIRVTypes.cpp:345
Type getElementType() const
Definition: SPIRVTypes.cpp:62
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:470
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:117
static void printSourceMemoryAccessAttribute(MemoryOpTy memoryOp, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs, Optional< spirv::MemoryAccess > memoryAccessAtrrValue=None, Optional< uint32_t > alignmentAttrValue=None)
Definition: SPIRVOps.cpp:291
static constexpr const char kAlignmentAttrName[]
Definition: SPIRVOps.cpp:40
Type getElementType() const
Returns the elements&#39; type (i.e, single element type).
static constexpr const char kCallee[]
Definition: SPIRVOps.cpp:43
static ParseResult parseBranchConditionalOp(OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:1334
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:58
virtual void printGenericOp(Operation *op, bool printOpName=true)=0
Print the entire operation with the default generic assembly form.
virtual ParseResult parseRParen()=0
Parse a ) token.
Block & back()
Definition: Region.h:64
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:338
static void buildLogicalBinaryOp(OpBuilder &builder, OperationState &state, Value lhs, Value rhs)
Definition: SPIRVOps.cpp:967
OpResult getResult(unsigned idx)
Get the &#39;idx&#39;th result of this operation.
Definition: Operation.h:276
A callable is either a symbol, or an SSA value, that is referenced by a call-like operation...
static constexpr const char kEqualSemanticsAttrName[]
Definition: SPIRVOps.cpp:48
static ParseResult parseExecutionModeOp(OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:1902
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:215
void addTypes(ArrayRef< Type > newTypes)
IntegerType getI1Type()
Definition: Builders.cpp:50
ParseResult parseKeyword(StringRef keyword, const Twine &msg="")
Parse a given keyword.
A trait to mark ops that can be enclosed/wrapped in a SpecConstantOperation op.
Definition: SPIRVOpTraits.h:32
static constexpr const char kValuesAttrName[]
Definition: SPIRVOps.cpp:60
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
static WalkResult advance()
Definition: Visitors.h:51
static void printGroupNonUniformArithmeticOp(Operation *groupOp, OpAsmPrinter &printer)
Definition: SPIRVOps.cpp:833
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:106
This represents an operation in an abstracted form, suitable for use with the builder APIs...
static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:2676
static ParseResult parseLogicalUnaryOp(OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:904
static ParseResult parseGlobalVariableOp(OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:2166
static ParseResult parseShiftOp(OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:933
static constexpr const char kMemoryScopeAttrName[]
Definition: SPIRVOps.cpp:54
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:34
An attribute that specifies the SPIR-V (version, capabilities, extensions) triple.
static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp)
Definition: SPIRVOps.cpp:850
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:73
static bool isDirectInModuleLikeOp(Operation *op)
Returns true if the given op is an module-like op that maintains a symbol table.
Definition: SPIRVOps.cpp:81
static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc)
Definition: SPIRVOps.cpp:993
static ArrayAttr getStrArrayAttrForEnumList(Builder &builder, ArrayRef< Ty > enumValues, function_ref< StringRef(Ty)> stringifyFn)
Definition: SPIRVOps.cpp:106
virtual ParseResult parseRSquare()=0
Parse a ] token.
static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:2478
static ParseResult parseAtomicUpdateOp(OpAsmParser &parser, OperationState &state, bool hasValue)
Definition: SPIRVOps.cpp:717
static LogicalResult extractValueFromConstOp(Operation *op, int32_t &value)
Definition: SPIRVOps.cpp:85
static constexpr const char kFnNameAttrName[]
Definition: SPIRVOps.cpp:49
static constexpr const char kSemanticsAttrName[]
Definition: SPIRVOps.cpp:55
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into &#39;result&#39; if the attributes keyword is present.
bool empty()
Definition: Block.h:139
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:391
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of...
static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser, OperationState &state)
Parses optional memory access attributes attached to a memory access operand/pointer.
Definition: SPIRVOps.cpp:202
void addSuccessors(Block *successor)
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
static ParseResult parseControlAttribute(OpAsmParser &parser, OperationState &state, StringRef attrName=spirv::attributeName< EnumClass >())
Parses Function, Selection and Loop control attributes.
Definition: SPIRVOps.cpp:181
static ParseResult parseAtomicExchangeOp(OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:1237
virtual Operation * parseGenericOperation(Block *insertBlock, Block::iterator insertPt)=0
Parse an operation in its generic form.
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:52
static constexpr const char kInitializerAttrName[]
Definition: SPIRVOps.cpp:52
static ParseResult parseLogicalBinaryOp(OpAsmParser &parser, OperationState &result)
Definition: SPIRVOps.cpp:916
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:230
static LogicalResult verifyCopyMemory(spirv::CopyMemoryOp copyMemory)
Definition: SPIRVOps.cpp:3465
unsigned getNumElements() const
Definition: SPIRVTypes.cpp:989
NamedAttrList attributes
virtual void printSuccessorAndUseList(Block *successor, ValueRange succOperands)=0
Print the successor and its operands.
Type getType() const
Return the type of this attribute.
Definition: Attributes.h:64
static ParseResult parseCooperativeMatrixLoadNVOp(OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:3237
static constexpr const char kSourceAlignmentAttrName[]
Definition: SPIRVOps.cpp:41
virtual InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values...
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:279
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:3039
Region * addRegion()
Create a region that should be attached to the operation.
Type getElementType(unsigned) const
Definition: SPIRVTypes.cpp:991
static ParseResult parseSpecConstantOperationOp(OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:3647
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
Dim
Dimension level type for a tensor (undef means index does not appear).
Definition: Merger.h:24
static LogicalResult verifyImageOperands(Op imageOp, spirv::ImageOperandsAttr attr, Operation::operand_range operands)
Definition: SPIRVOps.cpp:343
This class is a general helper class for creating context-global objects like types, attributes, and affine expressions.
Definition: Builders.h:49
static LogicalResult verifyCastOp(Operation *op, bool requireSameBitWidth=true, bool skipBitWidthCheck=false)
Definition: SPIRVOps.cpp:370
static constexpr const char kExecutionScopeAttrName[]
Definition: SPIRVOps.cpp:47
Type getType() const
Return the type of this value.
Definition: Value.h:117
TargetEnvAttr getDefaultTargetEnv(MLIRContext *context)
Returns the default target environment: SPIR-V 1.0 with Shader capability and no extra extensions...
Location getLoc()
The source location the operation was defined or derived from.
Definition: OpDefinition.h:124
unsigned getNumSuccessors()
Definition: Block.cpp:236
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:797
static constexpr const char kCompositeSpecConstituentsName[]
Definition: SPIRVOps.cpp:61
unsigned getRows() const
return the number of rows of the matrix.
Definition: SPIRVTypes.cpp:234
static LogicalResult verifyMatrixTimesScalar(spirv::MatrixTimesScalarOp op)
Definition: SPIRVOps.cpp:3354
U dyn_cast() const
Definition: Attributes.h:117
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation...
Definition: Visitors.cpp:24
static int64_t getNumElements(ShapedType type)
Definition: TensorOps.cpp:678
static LogicalResult verifyMatrixTimesMatrix(spirv::MatrixTimesMatrixOp op)
Definition: SPIRVOps.cpp:3521
type_range getTypes() const
static VectorType vectorType(CodeGen &codegen, Type etp)
Constructs vector type.
static LogicalResult verifyShiftOp(Operation *op)
Definition: SPIRVOps.cpp:957
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:87
virtual ParseResult parseType(Type &result)=0
Parse a type.
static ParseResult parseVariableOp(OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:3107
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
DenseIntElementsAttr getI32VectorAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:109
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with &#39;attribute...
SPIR-V struct type.
Definition: SPIRVTypes.h:278
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:91
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "&#39;dim&#39; op " which is convenient for verifiers...
Definition: Operation.cpp:624
This class allows for representing and managing the symbol table used by operations with the &#39;SymbolT...
Definition: SymbolTable.h:23
This class implements the operand iterators for the Operation class.
This provides public APIs that all operations should have.
static LogicalResult verifyTranspose(spirv::TransposeOp op)
Definition: SPIRVOps.cpp:3496
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:367
static LogicalResult verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op)
Definition: SPIRVOps.cpp:3329
static ParseResult parseCooperativeMatrixStoreNVOp(OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:3291
static constexpr const char kBranchWeightAttrName[]
Definition: SPIRVOps.cpp:42
static auto concatElemAndIndices(Op op)
Definition: SPIRVOps.cpp:3924
static void printMemoryAccessAttribute(MemoryOpTy memoryOp, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs, Optional< spirv::MemoryAccess > memoryAccessAtrrValue=None, Optional< uint32_t > alignmentAttrValue=None)
Definition: SPIRVOps.cpp:261
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into &#39;result&#39; if it is present.
bool hasValue() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:62
static constexpr const char kGroupOperationAttrName[]
Definition: SPIRVOps.cpp:50
virtual ParseResult parseEqual()=0
Parse a = token.
Scope getScope() const
Return the scope of the cooperative matrix.
Definition: SPIRVTypes.cpp:232
unsigned getNumColumns() const
Returns the number of columns.
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:67
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=llvm::None, ArrayRef< Location > locs=llvm::None)
Add new block with &#39;argTypes&#39; arguments and set the insertion point to the end of it...
Definition: Builders.cpp:353
unsigned getNumRows() const
Returns the number of rows.
StringRef stringifyTypeName< IntegerType >()
Definition: SPIRVOps.cpp:763
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "&#39;dim&#39; op " which is convenient for verifiers...
Definition: Operation.cpp:518
static ParseResult parseEnumKeywordAttr(EnumClass &value, ParserType &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next keyword in parser as an enumerant of the given EnumClass.
Definition: ParserUtils.h:27
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:57
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
bool isa() const
Definition: Types.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:231
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:61
static ParseResult parseSubgroupBlockWriteINTELOp(OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:2358
This class represents success/failure for operation parsing.
Definition: OpDefinition.h:36
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
static void printImageOperands(OpAsmPrinter &printer, Operation *imageOp, spirv::ImageOperandsAttr attr)
Definition: SPIRVOps.cpp:334
static constexpr const char kValueAttrName[]
Definition: SPIRVOps.cpp:59
This class helps build Operations.
Definition: Builders.h:177
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:205
This class provides an abstraction over the different types of ranges over Values.
ImageSamplingInfo getSamplingInfo() const
Definition: SPIRVTypes.cpp:341
static ParseResult parseCopyMemoryOp(OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:3423
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:201
static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:2422
ParseResult addTypesToList(ArrayRef< Type > types, SmallVectorImpl< Type > &result)
Add the specified types to the end of the specified type list and return success. ...
static ParseResult parseAccessChainOp(OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:1051
static constexpr const char kUnequalSemanticsAttrName[]
Definition: SPIRVOps.cpp:58
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
virtual ParseResult parseOptionalSymbolName(StringAttr &result, StringRef attrName, NamedAttrList &attrs)=0
Parse an optional -identifier and store it (without the &#39;@&#39; symbol) in a string attribute named &#39;attr...
unsigned getColumns() const
return the number of columns of the matrix.
Definition: SPIRVTypes.cpp:236
Square brackets surrounding zero or more operands.
static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices)
Definition: SPIRVOps.cpp:1108
U cast() const
Definition: Types.h:250
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
static Value max(ImplicitLocOpBuilder &builder, Value a, Value b)
static LogicalResult verifyAtomicCompareExchangeImpl(T atomOp)
Definition: SPIRVOps.cpp:1195
SmallVector< Type, 4 > types
Types of the results of this operation.
ArrayAttr getStrArrayAttr(ArrayRef< StringRef > values)
Definition: Builders.cpp:246