MLIR  15.0.0git
SPIRVOps.cpp
Go to the documentation of this file.
1 //===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the operations in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/BuiltinOps.h"
23 #include "mlir/IR/BuiltinTypes.h"
25 #include "mlir/IR/OpDefinition.h"
27 #include "mlir/IR/TypeUtilities.h"
29 #include "llvm/ADT/APFloat.h"
30 #include "llvm/ADT/APInt.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/ADT/StringExtras.h"
33 #include "llvm/ADT/bit.h"
34 
35 using namespace mlir;
36 
37 // TODO: generate these strings using ODS.
// Attribute names and assembly keywords shared by the custom parsers, printers
// and verifiers in this file. Each constant is a file-local, NUL-terminated
// character array, so it can be passed wherever a StringRef is expected.
38 static constexpr const char kMemoryAccessAttrName[] = "memory_access";
39 static constexpr const char kSourceMemoryAccessAttrName[] =
40  "source_memory_access";
41 static constexpr const char kAlignmentAttrName[] = "alignment";
42 static constexpr const char kSourceAlignmentAttrName[] = "source_alignment";
43 static constexpr const char kBranchWeightAttrName[] = "branch_weights";
44 static constexpr const char kCallee[] = "callee";
// Keyword printed in front of the parenthesized cluster-size operand.
45 static constexpr const char kClusterSize[] = "cluster_size";
// Optional keyword that introduces a parenthesized control enumerant.
46 static constexpr const char kControl[] = "control";
47 static constexpr const char kDefaultValueAttrName[] = "default_value";
48 static constexpr const char kExecutionScopeAttrName[] = "execution_scope";
49 static constexpr const char kEqualSemanticsAttrName[] = "equal_semantics";
50 static constexpr const char kFnNameAttrName[] = "fn";
51 static constexpr const char kGroupOperationAttrName[] = "group_operation";
52 static constexpr const char kIndicesAttrName[] = "indices";
53 static constexpr const char kInitializerAttrName[] = "initializer";
54 static constexpr const char kInterfaceAttrName[] = "interface";
55 static constexpr const char kMemoryScopeAttrName[] = "memory_scope";
56 static constexpr const char kSemanticsAttrName[] = "semantics";
57 static constexpr const char kSpecIdAttrName[] = "spec_id";
58 static constexpr const char kTypeAttrName[] = "type";
59 static constexpr const char kUnequalSemanticsAttrName[] = "unequal_semantics";
60 static constexpr const char kValueAttrName[] = "value";
61 static constexpr const char kValuesAttrName[] = "values";
62 static constexpr const char kCompositeSpecConstituentsName[] = "constituents";
63 
64 //===----------------------------------------------------------------------===//
65 // Common utility functions
66 //===----------------------------------------------------------------------===//
67 
69  OperationState &result) {
71  Type type;
72  // If the operand list is in-between parentheses, then we have a generic form.
73  // (see the fallback in `printOneResultOp`).
74  SMLoc loc = parser.getCurrentLocation();
75  if (!parser.parseOptionalLParen()) {
76  if (parser.parseOperandList(ops) || parser.parseRParen() ||
77  parser.parseOptionalAttrDict(result.attributes) ||
78  parser.parseColon() || parser.parseType(type))
79  return failure();
80  auto fnType = type.dyn_cast<FunctionType>();
81  if (!fnType) {
82  parser.emitError(loc, "expected function type");
83  return failure();
84  }
85  if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands))
86  return failure();
87  result.addTypes(fnType.getResults());
88  return success();
89  }
90  return failure(parser.parseOperandList(ops) ||
91  parser.parseOptionalAttrDict(result.attributes) ||
92  parser.parseColonType(type) ||
93  parser.resolveOperands(ops, type, result.operands) ||
94  parser.addTypeToList(type, result.types));
95 }
96 
/// Prints an op that has exactly one result.
///
/// When every operand type equals the result type, the operands are printed
/// followed by ` : <result type>`. Otherwise the op is printed in the generic
/// form (without the op name) so that no type information is lost.
97 static void printOneResultOp(Operation *op, OpAsmPrinter &p) {
98  assert(op->getNumResults() == 1 && "op should have one result");
99 
100  // If not all the operand and result types are the same, just use the
101  // generic assembly form to avoid omitting information in printing.
102  auto resultType = op->getResult(0).getType();
103  if (llvm::any_of(op->getOperandTypes(),
104  [&](Type type) { return type != resultType; })) {
105  p.printGenericOp(op, /*printOpName=*/false);
106  return;
107  }
108 
109  p << ' ';
110  p.printOperands(op->getOperands());
// NOTE(review): the embedded listing numbering jumps from 110 to 112 here, so
// a statement may be missing from this copy - confirm against the upstream
// file before relying on this body.
112  // Now we can output only one type for all operands and the result.
113  p << " : " << resultType;
114 }
115 
116 /// Returns true if the given op is a function-like op or nested in a
117 /// function-like op without a module-like op in the middle.
119  if (!op)
120  return false;
121  if (op->hasTrait<OpTrait::SymbolTable>())
122  return false;
123  if (isa<FunctionOpInterface>(op))
124  return true;
126 }
127 
128 /// Returns true if the given op is a module-like op that maintains a symbol
129 /// table.
131  return op && op->hasTrait<OpTrait::SymbolTable>();
132 }
133 
135  auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
136  if (!constOp) {
137  return failure();
138  }
139  auto valueAttr = constOp.value();
140  auto integerValueAttr = valueAttr.dyn_cast<IntegerAttr>();
141  if (!integerValueAttr) {
142  return failure();
143  }
144 
145  if (integerValueAttr.getType().isSignlessInteger())
146  value = integerValueAttr.getInt();
147  else
148  value = integerValueAttr.getSInt();
149 
150  return success();
151 }
152 
153 template <typename Ty>
154 static ArrayAttr
156  function_ref<StringRef(Ty)> stringifyFn) {
157  if (enumValues.empty()) {
158  return nullptr;
159  }
160  SmallVector<StringRef, 1> enumValStrs;
161  enumValStrs.reserve(enumValues.size());
162  for (auto val : enumValues) {
163  enumValStrs.emplace_back(stringifyFn(val));
164  }
165  return builder.getStrArrayAttr(enumValStrs);
166 }
167 
168 /// Parses the next string attribute in `parser` as an enumerant of the given
169 /// `EnumClass`.
170 template <typename EnumClass>
171 static ParseResult
173  StringRef attrName = spirv::attributeName<EnumClass>()) {
174  Attribute attrVal;
175  NamedAttrList attr;
176  auto loc = parser.getCurrentLocation();
177  if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(),
178  attrName, attr)) {
179  return failure();
180  }
181  if (!attrVal.isa<StringAttr>()) {
182  return parser.emitError(loc, "expected ")
183  << attrName << " attribute specified as string";
184  }
185  auto attrOptional =
186  spirv::symbolizeEnum<EnumClass>(attrVal.cast<StringAttr>().getValue());
187  if (!attrOptional) {
188  return parser.emitError(loc, "invalid ")
189  << attrName << " attribute specification: " << attrVal;
190  }
191  value = *attrOptional;
192  return success();
193 }
194 
195 /// Parses the next string attribute in `parser` as an enumerant of the given
196 /// `EnumClass` and inserts the enumerant into `state` as an 32-bit integer
197 /// attribute with the enum class's name as attribute name.
198 template <typename EnumClass>
199 static ParseResult
200 parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, OperationState &state,
201  StringRef attrName = spirv::attributeName<EnumClass>()) {
202  if (parseEnumStrAttr(value, parser)) {
203  return failure();
204  }
205  state.addAttribute(attrName, parser.getBuilder().getI32IntegerAttr(
206  llvm::bit_cast<int32_t>(value)));
207  return success();
208 }
209 
210 /// Parses the next keyword in `parser` as an enumerant of the given `EnumClass`
211 /// and inserts the enumerant into `state` as an 32-bit integer attribute with
212 /// the enum class's name as attribute name.
213 template <typename EnumClass>
214 static ParseResult
215 parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser,
216  OperationState &state,
217  StringRef attrName = spirv::attributeName<EnumClass>()) {
218  if (parseEnumKeywordAttr(value, parser)) {
219  return failure();
220  }
221  state.addAttribute(attrName, parser.getBuilder().getI32IntegerAttr(
222  llvm::bit_cast<int32_t>(value)));
223  return success();
224 }
225 
226 /// Parses Function, Selection and Loop control attributes. If no control is
227 /// specified, "None" is used as a default.
228 template <typename EnumClass>
229 static ParseResult
231  StringRef attrName = spirv::attributeName<EnumClass>()) {
232  if (succeeded(parser.parseOptionalKeyword(kControl))) {
233  EnumClass control;
234  if (parser.parseLParen() || parseEnumKeywordAttr(control, parser, state) ||
235  parser.parseRParen())
236  return failure();
237  return success();
238  }
239  // Set control to "None" otherwise.
240  Builder builder = parser.getBuilder();
241  state.addAttribute(attrName, builder.getI32IntegerAttr(0));
242  return success();
243 }
244 
245 /// Parses optional memory access attributes attached to a memory access
246 /// operand/pointer. Specifically, parses the following syntax:
247 /// (`[` memory-access `]`)?
248 /// where:
249 /// memory-access ::= `"None"` | `"Volatile"` | `"Aligned", `
250 /// integer-literal | `"NonTemporal"`
252  OperationState &state) {
253  // Parse an optional list of attributes staring with '['
254  if (parser.parseOptionalLSquare()) {
255  // Nothing to do
256  return success();
257  }
258 
259  spirv::MemoryAccess memoryAccessAttr;
260  if (parseEnumStrAttr(memoryAccessAttr, parser, state,
262  return failure();
263  }
264 
265  if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) {
266  // Parse integer attribute for alignment.
267  Attribute alignmentAttr;
268  Type i32Type = parser.getBuilder().getIntegerType(32);
269  if (parser.parseComma() ||
270  parser.parseAttribute(alignmentAttr, i32Type, kAlignmentAttrName,
271  state.attributes)) {
272  return failure();
273  }
274  }
275  return parser.parseRSquare();
276 }
277 
278 // TODO: Merge this and the previous function into one template parameterized
279 // by the memory access attribute name and the alignment attribute name. Doing
280 // so currently makes VS2017 produce an internal error (at the call site) that
281 // is not detailed enough to understand what is happening.
283  OperationState &state) {
284  // Parse an optional list of attributes staring with '['
285  if (parser.parseOptionalLSquare()) {
286  // Nothing to do
287  return success();
288  }
289 
290  spirv::MemoryAccess memoryAccessAttr;
291  if (parseEnumStrAttr(memoryAccessAttr, parser, state,
293  return failure();
294  }
295 
296  if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) {
297  // Parse integer attribute for alignment.
298  Attribute alignmentAttr;
299  Type i32Type = parser.getBuilder().getIntegerType(32);
300  if (parser.parseComma() ||
301  parser.parseAttribute(alignmentAttr, i32Type, kSourceAlignmentAttrName,
302  state.attributes)) {
303  return failure();
304  }
305  }
306  return parser.parseRSquare();
307 }
308 
309 template <typename MemoryOpTy>
311  MemoryOpTy memoryOp, OpAsmPrinter &printer,
312  SmallVectorImpl<StringRef> &elidedAttrs,
313  Optional<spirv::MemoryAccess> memoryAccessAtrrValue = None,
314  Optional<uint32_t> alignmentAttrValue = None) {
315  // Print optional memory access attribute.
316  if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
317  : memoryOp.memory_access())) {
318  elidedAttrs.push_back(kMemoryAccessAttrName);
319 
320  printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
321 
322  if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
323  // Print integer alignment attribute.
324  if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
325  : memoryOp.alignment())) {
326  elidedAttrs.push_back(kAlignmentAttrName);
327  printer << ", " << alignment;
328  }
329  }
330  printer << "]";
331  }
332  elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
333 }
334 
335 // TODO: Merge this and the previous function into one template parameterized
336 // by the memory access attribute name and the alignment attribute name. Doing
337 // so currently makes VS2017 produce an internal error (at the call site) that
338 // is not detailed enough to understand what is happening.
339 template <typename MemoryOpTy>
341  MemoryOpTy memoryOp, OpAsmPrinter &printer,
342  SmallVectorImpl<StringRef> &elidedAttrs,
343  Optional<spirv::MemoryAccess> memoryAccessAtrrValue = None,
344  Optional<uint32_t> alignmentAttrValue = None) {
345 
346  printer << ", ";
347 
348  // Print optional memory access attribute.
349  if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
350  : memoryOp.memory_access())) {
351  elidedAttrs.push_back(kSourceMemoryAccessAttrName);
352 
353  printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
354 
355  if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
356  // Print integer alignment attribute.
357  if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
358  : memoryOp.alignment())) {
359  elidedAttrs.push_back(kSourceAlignmentAttrName);
360  printer << ", " << alignment;
361  }
362  }
363  printer << "]";
364  }
365  elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
366 }
367 
369  spirv::ImageOperandsAttr &attr) {
370  // Expect image operands
371  if (parser.parseOptionalLSquare())
372  return success();
373 
374  spirv::ImageOperands imageOperands;
375  if (parseEnumStrAttr(imageOperands, parser))
376  return failure();
377 
378  attr = spirv::ImageOperandsAttr::get(parser.getContext(), imageOperands);
379 
380  return parser.parseRSquare();
381 }
382 
383 static void printImageOperands(OpAsmPrinter &printer, Operation *imageOp,
384  spirv::ImageOperandsAttr attr) {
385  if (attr) {
386  auto strImageOperands = stringifyImageOperands(attr.getValue());
387  printer << "[\"" << strImageOperands << "\"]";
388  }
389 }
390 
391 template <typename Op>
393  spirv::ImageOperandsAttr attr,
394  Operation::operand_range operands) {
395  if (!attr) {
396  if (operands.empty())
397  return success();
398 
399  return imageOp.emitError("the Image Operands should encode what operands "
400  "follow, as per Image Operands");
401  }
402 
403  // TODO: Add the validation rules for the following Image Operands.
404  spirv::ImageOperands noSupportOperands =
405  spirv::ImageOperands::Bias | spirv::ImageOperands::Lod |
406  spirv::ImageOperands::Grad | spirv::ImageOperands::ConstOffset |
407  spirv::ImageOperands::Offset | spirv::ImageOperands::ConstOffsets |
408  spirv::ImageOperands::Sample | spirv::ImageOperands::MinLod |
409  spirv::ImageOperands::MakeTexelAvailable |
410  spirv::ImageOperands::MakeTexelVisible |
411  spirv::ImageOperands::SignExtend | spirv::ImageOperands::ZeroExtend;
412 
413  if (spirv::bitEnumContains(attr.getValue(), noSupportOperands))
414  llvm_unreachable("unimplemented operands of Image Operands");
415 
416  return success();
417 }
418 
420  bool requireSameBitWidth = true,
421  bool skipBitWidthCheck = false) {
422  // Some CastOps have no limit on bit widths for result and operand type.
423  if (skipBitWidthCheck)
424  return success();
425 
426  Type operandType = op->getOperand(0).getType();
427  Type resultType = op->getResult(0).getType();
428 
429  // ODS checks that result type and operand type have the same shape.
430  if (auto vectorType = operandType.dyn_cast<VectorType>()) {
431  operandType = vectorType.getElementType();
432  resultType = resultType.cast<VectorType>().getElementType();
433  }
434 
435  if (auto coopMatrixType =
436  operandType.dyn_cast<spirv::CooperativeMatrixNVType>()) {
437  operandType = coopMatrixType.getElementType();
438  resultType =
440  }
441 
442  auto operandTypeBitWidth = operandType.getIntOrFloatBitWidth();
443  auto resultTypeBitWidth = resultType.getIntOrFloatBitWidth();
444  auto isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth;
445 
446  if (requireSameBitWidth) {
447  if (!isSameBitWidth) {
448  return op->emitOpError(
449  "expected the same bit widths for operand type and result "
450  "type, but provided ")
451  << operandType << " and " << resultType;
452  }
453  return success();
454  }
455 
456  if (isSameBitWidth) {
457  return op->emitOpError(
458  "expected the different bit widths for operand type and result "
459  "type, but provided ")
460  << operandType << " and " << resultType;
461  }
462  return success();
463 }
464 
465 template <typename MemoryOpTy>
466 static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
467  // ODS checks for attributes values. Just need to verify that if the
468  // memory-access attribute is Aligned, then the alignment attribute must be
469  // present.
470  auto *op = memoryOp.getOperation();
471  auto memAccessAttr = op->getAttr(kMemoryAccessAttrName);
472  if (!memAccessAttr) {
473  // Alignment attribute shouldn't be present if memory access attribute is
474  // not present.
475  if (op->getAttr(kAlignmentAttrName)) {
476  return memoryOp.emitOpError(
477  "invalid alignment specification without aligned memory access "
478  "specification");
479  }
480  return success();
481  }
482 
483  auto memAccessVal = memAccessAttr.template cast<IntegerAttr>();
484  auto memAccess = spirv::symbolizeMemoryAccess(memAccessVal.getInt());
485 
486  if (!memAccess) {
487  return memoryOp.emitOpError("invalid memory access specifier: ")
488  << memAccessVal;
489  }
490 
491  if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
492  if (!op->getAttr(kAlignmentAttrName)) {
493  return memoryOp.emitOpError("missing alignment value");
494  }
495  } else {
496  if (op->getAttr(kAlignmentAttrName)) {
497  return memoryOp.emitOpError(
498  "invalid alignment specification with non-aligned memory access "
499  "specification");
500  }
501  }
502  return success();
503 }
504 
505 // TODO: Merge this and the previous function into one template parameterized
506 // by the memory access attribute name and the alignment attribute name. Doing
507 // so currently makes VS2017 produce an internal error (at the call site) that
508 // is not detailed enough to understand what is happening.
509 template <typename MemoryOpTy>
511  // ODS checks for attributes values. Just need to verify that if the
512  // memory-access attribute is Aligned, then the alignment attribute must be
513  // present.
514  auto *op = memoryOp.getOperation();
515  auto memAccessAttr = op->getAttr(kSourceMemoryAccessAttrName);
516  if (!memAccessAttr) {
517  // Alignment attribute shouldn't be present if memory access attribute is
518  // not present.
519  if (op->getAttr(kSourceAlignmentAttrName)) {
520  return memoryOp.emitOpError(
521  "invalid alignment specification without aligned memory access "
522  "specification");
523  }
524  return success();
525  }
526 
527  auto memAccessVal = memAccessAttr.template cast<IntegerAttr>();
528  auto memAccess = spirv::symbolizeMemoryAccess(memAccessVal.getInt());
529 
530  if (!memAccess) {
531  return memoryOp.emitOpError("invalid memory access specifier: ")
532  << memAccessVal;
533  }
534 
535  if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
536  if (!op->getAttr(kSourceAlignmentAttrName)) {
537  return memoryOp.emitOpError("missing alignment value");
538  }
539  } else {
540  if (op->getAttr(kSourceAlignmentAttrName)) {
541  return memoryOp.emitOpError(
542  "invalid alignment specification with non-aligned memory access "
543  "specification");
544  }
545  }
546  return success();
547 }
548 
549 static LogicalResult
550 verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics) {
551  // According to the SPIR-V specification:
552  // "Despite being a mask and allowing multiple bits to be combined, it is
553  // invalid for more than one of these four bits to be set: Acquire, Release,
554  // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and
555  // Release semantics is done by setting the AcquireRelease bit, not by setting
556  // two bits."
557  auto atMostOneInSet = spirv::MemorySemantics::Acquire |
558  spirv::MemorySemantics::Release |
559  spirv::MemorySemantics::AcquireRelease |
560  spirv::MemorySemantics::SequentiallyConsistent;
561 
562  auto bitCount = llvm::countPopulation(
563  static_cast<uint32_t>(memorySemantics & atMostOneInSet));
564  if (bitCount > 1) {
565  return op->emitError(
566  "expected at most one of these four memory constraints "
567  "to be set: `Acquire`, `Release`,"
568  "`AcquireRelease` or `SequentiallyConsistent`");
569  }
570  return success();
571 }
572 
573 template <typename LoadStoreOpTy>
574 static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr,
575  Value val) {
576  // ODS already checks ptr is spirv::PointerType. Just check that the pointee
577  // type of the pointer and the type of the value are the same
578  //
579  // TODO: Check that the value type satisfies restrictions of
580  // SPIR-V OpLoad/OpStore operations
581  if (val.getType() !=
582  ptr.getType().cast<spirv::PointerType>().getPointeeType()) {
583  return op.emitOpError("mismatch in result type and pointer type");
584  }
585  return success();
586 }
587 
588 template <typename BlockReadWriteOpTy>
590  Value ptr, Value val) {
591  auto valType = val.getType();
592  if (auto valVecTy = valType.dyn_cast<VectorType>())
593  valType = valVecTy.getElementType();
594 
595  if (valType != ptr.getType().cast<spirv::PointerType>().getPointeeType()) {
596  return op.emitOpError("mismatch in result type and pointer type");
597  }
598  return success();
599 }
600 
602  OperationState &state) {
603  auto builtInName = llvm::convertToSnakeFromCamelCase(
604  stringifyDecoration(spirv::Decoration::BuiltIn));
605  if (succeeded(parser.parseOptionalKeyword("bind"))) {
606  Attribute set, binding;
607  // Parse optional descriptor binding
608  auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
609  stringifyDecoration(spirv::Decoration::DescriptorSet));
610  auto bindingName = llvm::convertToSnakeFromCamelCase(
611  stringifyDecoration(spirv::Decoration::Binding));
612  Type i32Type = parser.getBuilder().getIntegerType(32);
613  if (parser.parseLParen() ||
614  parser.parseAttribute(set, i32Type, descriptorSetName,
615  state.attributes) ||
616  parser.parseComma() ||
617  parser.parseAttribute(binding, i32Type, bindingName,
618  state.attributes) ||
619  parser.parseRParen()) {
620  return failure();
621  }
622  } else if (succeeded(parser.parseOptionalKeyword(builtInName))) {
623  StringAttr builtIn;
624  if (parser.parseLParen() ||
625  parser.parseAttribute(builtIn, builtInName, state.attributes) ||
626  parser.parseRParen()) {
627  return failure();
628  }
629  }
630 
631  // Parse other attributes
632  if (parser.parseOptionalAttrDict(state.attributes))
633  return failure();
634 
635  return success();
636 }
637 
639  SmallVectorImpl<StringRef> &elidedAttrs) {
640  // Print optional descriptor binding
641  auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
642  stringifyDecoration(spirv::Decoration::DescriptorSet));
643  auto bindingName = llvm::convertToSnakeFromCamelCase(
644  stringifyDecoration(spirv::Decoration::Binding));
645  auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
646  auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
647  if (descriptorSet && binding) {
648  elidedAttrs.push_back(descriptorSetName);
649  elidedAttrs.push_back(bindingName);
650  printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
651  << ")";
652  }
653 
654  // Print BuiltIn attribute if present
655  auto builtInName = llvm::convertToSnakeFromCamelCase(
656  stringifyDecoration(spirv::Decoration::BuiltIn));
657  if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
658  printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
659  elidedAttrs.push_back(builtInName);
660  }
661 
662  printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
663 }
664 
665 // Get bit width of types.
666 static unsigned getBitWidth(Type type) {
667  if (type.isa<spirv::PointerType>()) {
668  // Just return 64 bits for pointer types for now.
669  // TODO: Make sure not caller relies on the actual pointer width value.
670  return 64;
671  }
672 
673  if (type.isIntOrFloat())
674  return type.getIntOrFloatBitWidth();
675 
676  if (auto vectorType = type.dyn_cast<VectorType>()) {
677  assert(vectorType.getElementType().isIntOrFloat());
678  return vectorType.getNumElements() *
679  vectorType.getElementType().getIntOrFloatBitWidth();
680  }
681  llvm_unreachable("unhandled bit width computation for type");
682 }
683 
684 /// Walks the given type hierarchy with the given indices, potentially down
685 /// to component granularity, to select an element type. Returns null type and
686 /// emits errors with the given loc on failure.
687 static Type
689  function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
690  if (indices.empty()) {
691  emitErrorFn("expected at least one index for spv.CompositeExtract");
692  return nullptr;
693  }
694 
695  for (auto index : indices) {
696  if (auto cType = type.dyn_cast<spirv::CompositeType>()) {
697  if (cType.hasCompileTimeKnownNumElements() &&
698  (index < 0 ||
699  static_cast<uint64_t>(index) >= cType.getNumElements())) {
700  emitErrorFn("index ") << index << " out of bounds for " << type;
701  return nullptr;
702  }
703  type = cType.getElementType(index);
704  } else {
705  emitErrorFn("cannot extract from non-composite type ")
706  << type << " with index " << index;
707  return nullptr;
708  }
709  }
710  return type;
711 }
712 
713 static Type
715  function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
716  auto indicesArrayAttr = indices.dyn_cast<ArrayAttr>();
717  if (!indicesArrayAttr) {
718  emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
719  return nullptr;
720  }
721  if (indicesArrayAttr.empty()) {
722  emitErrorFn("expected at least one index for spv.CompositeExtract");
723  return nullptr;
724  }
725 
726  SmallVector<int32_t, 2> indexVals;
727  for (auto indexAttr : indicesArrayAttr) {
728  auto indexIntAttr = indexAttr.dyn_cast<IntegerAttr>();
729  if (!indexIntAttr) {
730  emitErrorFn("expected an 32-bit integer for index, but found '")
731  << indexAttr << "'";
732  return nullptr;
733  }
734  indexVals.push_back(indexIntAttr.getInt());
735  }
736  return getElementType(type, indexVals, emitErrorFn);
737 }
738 
739 static Type getElementType(Type type, Attribute indices, Location loc) {
740  auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
741  return ::mlir::emitError(loc, err);
742  };
743  return getElementType(type, indices, errorFn);
744 }
745 
746 static Type getElementType(Type type, Attribute indices, OpAsmParser &parser,
747  SMLoc loc) {
748  auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
749  return parser.emitError(loc, err);
750  };
751  return getElementType(type, indices, errorFn);
752 }
753 
754 /// Returns true if the given `block` only contains one `spv.mlir.merge` op.
755 static inline bool isMergeBlock(Block &block) {
756  return !block.empty() && std::next(block.begin()) == block.end() &&
757  isa<spirv::MergeOp>(block.front());
758 }
759 
760 //===----------------------------------------------------------------------===//
761 // Common parsers and printers
762 //===----------------------------------------------------------------------===//
763 
764 // Parses an atomic update op. If the update op does not take a value (like
765 // AtomicIIncrement) `hasValue` must be false.
767  OperationState &state, bool hasValue) {
768  spirv::Scope scope;
769  spirv::MemorySemantics memoryScope;
771  OpAsmParser::UnresolvedOperand ptrInfo, valueInfo;
772  Type type;
773  SMLoc loc;
774  if (parseEnumStrAttr(scope, parser, state, kMemoryScopeAttrName) ||
775  parseEnumStrAttr(memoryScope, parser, state, kSemanticsAttrName) ||
776  parser.parseOperandList(operandInfo, (hasValue ? 2 : 1)) ||
777  parser.getCurrentLocation(&loc) || parser.parseColonType(type))
778  return failure();
779 
780  auto ptrType = type.dyn_cast<spirv::PointerType>();
781  if (!ptrType)
782  return parser.emitError(loc, "expected pointer type");
783 
784  SmallVector<Type, 2> operandTypes;
785  operandTypes.push_back(ptrType);
786  if (hasValue)
787  operandTypes.push_back(ptrType.getPointeeType());
788  if (parser.resolveOperands(operandInfo, operandTypes, parser.getNameLoc(),
789  state.operands))
790  return failure();
791  return parser.addTypeToList(ptrType.getPointeeType(), state.types);
792 }
793 
794 // Prints an atomic update op.
795 static void printAtomicUpdateOp(Operation *op, OpAsmPrinter &printer) {
796  printer << " \"";
797  auto scopeAttr = op->getAttrOfType<IntegerAttr>(kMemoryScopeAttrName);
798  printer << spirv::stringifyScope(
799  static_cast<spirv::Scope>(scopeAttr.getInt()))
800  << "\" \"";
801  auto memorySemanticsAttr = op->getAttrOfType<IntegerAttr>(kSemanticsAttrName);
802  printer << spirv::stringifyMemorySemantics(
803  static_cast<spirv::MemorySemantics>(
804  memorySemanticsAttr.getInt()))
805  << "\" " << op->getOperands() << " : " << op->getOperand(0).getType();
806 }
807 
/// Returns a human-readable name for the type class `T`, for use in
/// diagnostics. Only the primary template is declared here; the strings are
/// supplied by explicit specializations.
808 template <typename T>
809 static StringRef stringifyTypeName();
810 
811 template <>
813  return "integer";
814 }
815 
816 template <>
818  return "float";
819 }
820 
821 // Verifies an atomic update op.
822 template <typename ExpectedElementType>
824  auto ptrType = op->getOperand(0).getType().cast<spirv::PointerType>();
825  auto elementType = ptrType.getPointeeType();
826  if (!elementType.isa<ExpectedElementType>())
827  return op->emitOpError() << "pointer operand must point to an "
828  << stringifyTypeName<ExpectedElementType>()
829  << " value, found " << elementType;
830 
831  if (op->getNumOperands() > 1) {
832  auto valueType = op->getOperand(1).getType();
833  if (valueType != elementType)
834  return op->emitOpError("expected value to have the same type as the "
835  "pointer operand's pointee type ")
836  << elementType << ", but found " << valueType;
837  }
838  auto memorySemantics = static_cast<spirv::MemorySemantics>(
839  op->getAttrOfType<IntegerAttr>(kSemanticsAttrName).getInt());
840  if (failed(verifyMemorySemantics(op, memorySemantics))) {
841  return failure();
842  }
843  return success();
844 }
845 
847  OperationState &state) {
848  spirv::Scope executionScope;
849  spirv::GroupOperation groupOperation;
851  if (parseEnumStrAttr(executionScope, parser, state,
853  parseEnumStrAttr(groupOperation, parser, state,
855  parser.parseOperand(valueInfo))
856  return failure();
857 
860  clusterSizeInfo = OpAsmParser::UnresolvedOperand();
861  if (parser.parseLParen() || parser.parseOperand(*clusterSizeInfo) ||
862  parser.parseRParen())
863  return failure();
864  }
865 
866  Type resultType;
867  if (parser.parseColonType(resultType))
868  return failure();
869 
870  if (parser.resolveOperand(valueInfo, resultType, state.operands))
871  return failure();
872 
873  if (clusterSizeInfo) {
874  Type i32Type = parser.getBuilder().getIntegerType(32);
875  if (parser.resolveOperand(*clusterSizeInfo, i32Type, state.operands))
876  return failure();
877  }
878 
879  return parser.addTypeToList(resultType, state.types);
880 }
881 
883  OpAsmPrinter &printer) {
884  printer << " \""
885  << stringifyScope(static_cast<spirv::Scope>(
886  groupOp->getAttrOfType<IntegerAttr>(kExecutionScopeAttrName)
887  .getInt()))
888  << "\" \""
889  << stringifyGroupOperation(static_cast<spirv::GroupOperation>(
890  groupOp->getAttrOfType<IntegerAttr>(kGroupOperationAttrName)
891  .getInt()))
892  << "\" " << groupOp->getOperand(0);
893 
894  if (groupOp->getNumOperands() > 1)
895  printer << " " << kClusterSize << '(' << groupOp->getOperand(1) << ')';
896  printer << " : " << groupOp->getResult(0).getType();
897 }
898 
900  spirv::Scope scope = static_cast<spirv::Scope>(
901  groupOp->getAttrOfType<IntegerAttr>(kExecutionScopeAttrName).getInt());
902  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
903  return groupOp->emitOpError(
904  "execution scope must be 'Workgroup' or 'Subgroup'");
905 
906  spirv::GroupOperation operation = static_cast<spirv::GroupOperation>(
907  groupOp->getAttrOfType<IntegerAttr>(kGroupOperationAttrName).getInt());
908  if (operation == spirv::GroupOperation::ClusteredReduce &&
909  groupOp->getNumOperands() == 1)
910  return groupOp->emitOpError("cluster size operand must be provided for "
911  "'ClusteredReduce' group operation");
912  if (groupOp->getNumOperands() > 1) {
913  Operation *sizeOp = groupOp->getOperand(1).getDefiningOp();
914  int32_t clusterSize = 0;
915 
916  // TODO: support specialization constant here.
917  if (failed(extractValueFromConstOp(sizeOp, clusterSize)))
918  return groupOp->emitOpError(
919  "cluster size operand must come from a constant op");
920 
921  if (!llvm::isPowerOf2_32(clusterSize))
922  return groupOp->emitOpError(
923  "cluster size operand must be a power of two");
924  }
925  return success();
926 }
927 
928 /// Result of a logical op must be a scalar or vector of boolean type.
929 static Type getUnaryOpResultType(Type operandType) {
930  Builder builder(operandType.getContext());
931  Type resultType = builder.getIntegerType(1);
932  if (auto vecType = operandType.dyn_cast<VectorType>())
933  return VectorType::get(vecType.getNumElements(), resultType);
934  return resultType;
935 }
936 
938  if (op->getOperand(0).getType() != op->getResult(0).getType()) {
939  return op->emitError("expected the same type for the first operand and "
940  "result, but provided ")
941  << op->getOperand(0).getType() << " and "
942  << op->getResult(0).getType();
943  }
944  return success();
945 }
946 
947 static void buildLogicalBinaryOp(OpBuilder &builder, OperationState &state,
948  Value lhs, Value rhs) {
949  assert(lhs.getType() == rhs.getType());
950 
951  Type boolType = builder.getI1Type();
952  if (auto vecType = lhs.getType().dyn_cast<VectorType>())
953  boolType = VectorType::get(vecType.getShape(), boolType);
954  state.addTypes(boolType);
955 
956  state.addOperands({lhs, rhs});
957 }
958 
959 static void buildLogicalUnaryOp(OpBuilder &builder, OperationState &state,
960  Value value) {
961  Type boolType = builder.getI1Type();
962  if (auto vecType = value.getType().dyn_cast<VectorType>())
963  boolType = VectorType::get(vecType.getShape(), boolType);
964  state.addTypes(boolType);
965 
966  state.addOperands(value);
967 }
968 
969 //===----------------------------------------------------------------------===//
970 // spv.AccessChainOp
971 //===----------------------------------------------------------------------===//
972 
973 static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
974  auto ptrType = type.dyn_cast<spirv::PointerType>();
975  if (!ptrType) {
976  emitError(baseLoc, "'spv.AccessChain' op expected a pointer "
977  "to composite type, but provided ")
978  << type;
979  return nullptr;
980  }
981 
982  auto resultType = ptrType.getPointeeType();
983  auto resultStorageClass = ptrType.getStorageClass();
984  int32_t index = 0;
985 
986  for (auto indexSSA : indices) {
987  auto cType = resultType.dyn_cast<spirv::CompositeType>();
988  if (!cType) {
989  emitError(baseLoc,
990  "'spv.AccessChain' op cannot extract from non-composite type ")
991  << resultType << " with index " << index;
992  return nullptr;
993  }
994  index = 0;
995  if (resultType.isa<spirv::StructType>()) {
996  Operation *op = indexSSA.getDefiningOp();
997  if (!op) {
998  emitError(baseLoc, "'spv.AccessChain' op index must be an "
999  "integer spv.Constant to access "
1000  "element of spv.struct");
1001  return nullptr;
1002  }
1003 
1004  // TODO: this should be relaxed to allow
1005  // integer literals of other bitwidths.
1006  if (failed(extractValueFromConstOp(op, index))) {
1007  emitError(baseLoc,
1008  "'spv.AccessChain' index must be an integer spv.Constant to "
1009  "access element of spv.struct, but provided ")
1010  << op->getName();
1011  return nullptr;
1012  }
1013  if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) {
1014  emitError(baseLoc, "'spv.AccessChain' op index ")
1015  << index << " out of bounds for " << resultType;
1016  return nullptr;
1017  }
1018  }
1019  resultType = cType.getElementType(index);
1020  }
1021  return spirv::PointerType::get(resultType, resultStorageClass);
1022 }
1023 
1024 void spirv::AccessChainOp::build(OpBuilder &builder, OperationState &state,
1025  Value basePtr, ValueRange indices) {
1026  auto type = getElementPtrType(basePtr.getType(), indices, state.location);
1027  assert(type && "Unable to deduce return type based on basePtr and indices");
1028  build(builder, state, type, basePtr, indices);
1029 }
1030 
1031 ParseResult spirv::AccessChainOp::parse(OpAsmParser &parser,
1032  OperationState &state) {
1035  Type type;
1036  auto loc = parser.getCurrentLocation();
1037  SmallVector<Type, 4> indicesTypes;
1038 
1039  if (parser.parseOperand(ptrInfo) ||
1040  parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
1041  parser.parseColonType(type) ||
1042  parser.resolveOperand(ptrInfo, type, state.operands)) {
1043  return failure();
1044  }
1045 
1046  // Check that the provided indices list is not empty before parsing their
1047  // type list.
1048  if (indicesInfo.empty()) {
1049  return mlir::emitError(state.location, "'spv.AccessChain' op expected at "
1050  "least one index ");
1051  }
1052 
1053  if (parser.parseComma() || parser.parseTypeList(indicesTypes))
1054  return failure();
1055 
1056  // Check that the indices types list is not empty and that it has a one-to-one
1057  // mapping to the provided indices.
1058  if (indicesTypes.size() != indicesInfo.size()) {
1059  return mlir::emitError(state.location,
1060  "'spv.AccessChain' op indices types' count must be "
1061  "equal to indices info count");
1062  }
1063 
1064  if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands))
1065  return failure();
1066 
1067  auto resultType = getElementPtrType(
1068  type, llvm::makeArrayRef(state.operands).drop_front(), state.location);
1069  if (!resultType) {
1070  return failure();
1071  }
1072 
1073  state.addTypes(resultType);
1074  return success();
1075 }
1076 
1077 template <typename Op>
1078 static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer) {
1079  printer << ' ' << op.base_ptr() << '[' << indices
1080  << "] : " << op.base_ptr().getType() << ", " << indices.getTypes();
1081 }
1082 
1084  printAccessChain(*this, indices(), printer);
1085 }
1086 
1087 template <typename Op>
1088 static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) {
1089  auto resultType = getElementPtrType(accessChainOp.base_ptr().getType(),
1090  indices, accessChainOp.getLoc());
1091  if (!resultType)
1092  return failure();
1093 
1094  auto providedResultType =
1095  accessChainOp.getType().template dyn_cast<spirv::PointerType>();
1096  if (!providedResultType)
1097  return accessChainOp.emitOpError(
1098  "result type must be a pointer, but provided")
1099  << providedResultType;
1100 
1101  if (resultType != providedResultType)
1102  return accessChainOp.emitOpError("invalid result type: expected ")
1103  << resultType << ", but provided " << providedResultType;
1104 
1105  return success();
1106 }
1107 
1109  return verifyAccessChain(*this, indices());
1110 }
1111 
1112 //===----------------------------------------------------------------------===//
1113 // spv.mlir.addressof
1114 //===----------------------------------------------------------------------===//
1115 
1116 void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state,
1117  spirv::GlobalVariableOp var) {
1118  build(builder, state, var.type(), SymbolRefAttr::get(var));
1119 }
1120 
1122  auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
1123  SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(),
1124  variableAttr()));
1125  if (!varOp) {
1126  return emitOpError("expected spv.GlobalVariable symbol");
1127  }
1128  if (pointer().getType() != varOp.type()) {
1129  return emitOpError(
1130  "result type mismatch with the referenced global variable's type");
1131  }
1132  return success();
1133 }
1134 
1135 template <typename T>
1136 static void printAtomicCompareExchangeImpl(T atomOp, OpAsmPrinter &printer) {
1137  printer << " \"" << stringifyScope(atomOp.memory_scope()) << "\" \""
1138  << stringifyMemorySemantics(atomOp.equal_semantics()) << "\" \""
1139  << stringifyMemorySemantics(atomOp.unequal_semantics()) << "\" "
1140  << atomOp.getOperands() << " : " << atomOp.pointer().getType();
1141 }
1142 
1144  OperationState &state) {
1145  spirv::Scope memoryScope;
1146  spirv::MemorySemantics equalSemantics, unequalSemantics;
1148  Type type;
1149  if (parseEnumStrAttr(memoryScope, parser, state, kMemoryScopeAttrName) ||
1150  parseEnumStrAttr(equalSemantics, parser, state,
1152  parseEnumStrAttr(unequalSemantics, parser, state,
1154  parser.parseOperandList(operandInfo, 3))
1155  return failure();
1156 
1157  auto loc = parser.getCurrentLocation();
1158  if (parser.parseColonType(type))
1159  return failure();
1160 
1161  auto ptrType = type.dyn_cast<spirv::PointerType>();
1162  if (!ptrType)
1163  return parser.emitError(loc, "expected pointer type");
1164 
1165  if (parser.resolveOperands(
1166  operandInfo,
1167  {ptrType, ptrType.getPointeeType(), ptrType.getPointeeType()},
1168  parser.getNameLoc(), state.operands))
1169  return failure();
1170 
1171  return parser.addTypeToList(ptrType.getPointeeType(), state.types);
1172 }
1173 
1174 template <typename T>
1176  // According to the spec:
1177  // "The type of Value must be the same as Result Type. The type of the value
1178  // pointed to by Pointer must be the same as Result Type. This type must also
1179  // match the type of Comparator."
1180  if (atomOp.getType() != atomOp.value().getType())
1181  return atomOp.emitOpError("value operand must have the same type as the op "
1182  "result, but found ")
1183  << atomOp.value().getType() << " vs " << atomOp.getType();
1184 
1185  if (atomOp.getType() != atomOp.comparator().getType())
1186  return atomOp.emitOpError(
1187  "comparator operand must have the same type as the op "
1188  "result, but found ")
1189  << atomOp.comparator().getType() << " vs " << atomOp.getType();
1190 
1191  Type pointeeType = atomOp.pointer()
1192  .getType()
1193  .template cast<spirv::PointerType>()
1194  .getPointeeType();
1195  if (atomOp.getType() != pointeeType)
1196  return atomOp.emitOpError(
1197  "pointer operand's pointee type must have the same "
1198  "as the op result type, but found ")
1199  << pointeeType << " vs " << atomOp.getType();
1200 
1201  // TODO: Unequal cannot be set to Release or Acquire and Release.
1202  // In addition, Unequal cannot be set to a stronger memory-order than Equal.
1203 
1204  return success();
1205 }
1206 
1207 //===----------------------------------------------------------------------===//
1208 // spv.AtomicAndOp
1209 //===----------------------------------------------------------------------===//
1210 
1212  return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1213 }
1214 
1215 ParseResult spirv::AtomicAndOp::parse(OpAsmParser &parser,
1216  OperationState &result) {
1217  return ::parseAtomicUpdateOp(parser, result, true);
1218 }
1220  ::printAtomicUpdateOp(*this, p);
1221 }
1222 
1223 //===----------------------------------------------------------------------===//
1224 // spv.AtomicCompareExchangeOp
1225 //===----------------------------------------------------------------------===//
1226 
1229 }
1230 
1231 ParseResult spirv::AtomicCompareExchangeOp::parse(OpAsmParser &parser,
1232  OperationState &result) {
1234 }
1237 }
1238 
1239 //===----------------------------------------------------------------------===//
1240 // spv.AtomicCompareExchangeWeakOp
1241 //===----------------------------------------------------------------------===//
1242 
1245 }
1246 
1247 ParseResult spirv::AtomicCompareExchangeWeakOp::parse(OpAsmParser &parser,
1248  OperationState &result) {
1250 }
1253 }
1254 
1255 //===----------------------------------------------------------------------===//
1256 // spv.AtomicExchange
1257 //===----------------------------------------------------------------------===//
1258 
1260  printer << " \"" << stringifyScope(memory_scope()) << "\" \""
1261  << stringifyMemorySemantics(semantics()) << "\" " << getOperands()
1262  << " : " << pointer().getType();
1263 }
1264 
1265 ParseResult spirv::AtomicExchangeOp::parse(OpAsmParser &parser,
1266  OperationState &state) {
1267  spirv::Scope memoryScope;
1268  spirv::MemorySemantics semantics;
1270  Type type;
1271  if (parseEnumStrAttr(memoryScope, parser, state, kMemoryScopeAttrName) ||
1272  parseEnumStrAttr(semantics, parser, state, kSemanticsAttrName) ||
1273  parser.parseOperandList(operandInfo, 2))
1274  return failure();
1275 
1276  auto loc = parser.getCurrentLocation();
1277  if (parser.parseColonType(type))
1278  return failure();
1279 
1280  auto ptrType = type.dyn_cast<spirv::PointerType>();
1281  if (!ptrType)
1282  return parser.emitError(loc, "expected pointer type");
1283 
1284  if (parser.resolveOperands(operandInfo, {ptrType, ptrType.getPointeeType()},
1285  parser.getNameLoc(), state.operands))
1286  return failure();
1287 
1288  return parser.addTypeToList(ptrType.getPointeeType(), state.types);
1289 }
1290 
1292  if (getType() != value().getType())
1293  return emitOpError("value operand must have the same type as the op "
1294  "result, but found ")
1295  << value().getType() << " vs " << getType();
1296 
1297  Type pointeeType =
1298  pointer().getType().cast<spirv::PointerType>().getPointeeType();
1299  if (getType() != pointeeType)
1300  return emitOpError("pointer operand's pointee type must have the same "
1301  "as the op result type, but found ")
1302  << pointeeType << " vs " << getType();
1303 
1304  return success();
1305 }
1306 
1307 //===----------------------------------------------------------------------===//
1308 // spv.AtomicIAddOp
1309 //===----------------------------------------------------------------------===//
1310 
1312  return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1313 }
1314 
1315 ParseResult spirv::AtomicIAddOp::parse(OpAsmParser &parser,
1316  OperationState &result) {
1317  return ::parseAtomicUpdateOp(parser, result, true);
1318 }
1320  ::printAtomicUpdateOp(*this, p);
1321 }
1322 
1323 //===----------------------------------------------------------------------===//
1324 // spv.AtomicFAddEXTOp
1325 //===----------------------------------------------------------------------===//
1326 
1328  return ::verifyAtomicUpdateOp<FloatType>(getOperation());
1329 }
1330 
1331 ParseResult spirv::AtomicFAddEXTOp::parse(OpAsmParser &parser,
1332  OperationState &result) {
1333  return ::parseAtomicUpdateOp(parser, result, true);
1334 }
1336  ::printAtomicUpdateOp(*this, p);
1337 }
1338 
1339 //===----------------------------------------------------------------------===//
1340 // spv.AtomicIDecrementOp
1341 //===----------------------------------------------------------------------===//
1342 
1344  return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1345 }
1346 
1347 ParseResult spirv::AtomicIDecrementOp::parse(OpAsmParser &parser,
1348  OperationState &result) {
1349  return ::parseAtomicUpdateOp(parser, result, false);
1350 }
1352  ::printAtomicUpdateOp(*this, p);
1353 }
1354 
1355 //===----------------------------------------------------------------------===//
1356 // spv.AtomicIIncrementOp
1357 //===----------------------------------------------------------------------===//
1358 
1360  return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1361 }
1362 
1363 ParseResult spirv::AtomicIIncrementOp::parse(OpAsmParser &parser,
1364  OperationState &result) {
1365  return ::parseAtomicUpdateOp(parser, result, false);
1366 }
1368  ::printAtomicUpdateOp(*this, p);
1369 }
1370 
1371 //===----------------------------------------------------------------------===//
1372 // spv.AtomicISubOp
1373 //===----------------------------------------------------------------------===//
1374 
1376  return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1377 }
1378 
1379 ParseResult spirv::AtomicISubOp::parse(OpAsmParser &parser,
1380  OperationState &result) {
1381  return ::parseAtomicUpdateOp(parser, result, true);
1382 }
1384  ::printAtomicUpdateOp(*this, p);
1385 }
1386 
1387 //===----------------------------------------------------------------------===//
1388 // spv.AtomicOrOp
1389 //===----------------------------------------------------------------------===//
1390 
1392  return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1393 }
1394 
1395 ParseResult spirv::AtomicOrOp::parse(OpAsmParser &parser,
1396  OperationState &result) {
1397  return ::parseAtomicUpdateOp(parser, result, true);
1398 }
1400  ::printAtomicUpdateOp(*this, p);
1401 }
1402 
1403 //===----------------------------------------------------------------------===//
1404 // spv.AtomicSMaxOp
1405 //===----------------------------------------------------------------------===//
1406 
1408  return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1409 }
1410 
1411 ParseResult spirv::AtomicSMaxOp::parse(OpAsmParser &parser,
1412  OperationState &result) {
1413  return ::parseAtomicUpdateOp(parser, result, true);
1414 }
1416  ::printAtomicUpdateOp(*this, p);
1417 }
1418 
1419 //===----------------------------------------------------------------------===//
1420 // spv.AtomicSMinOp
1421 //===----------------------------------------------------------------------===//
1422 
1424  return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1425 }
1426 
1427 ParseResult spirv::AtomicSMinOp::parse(OpAsmParser &parser,
1428  OperationState &result) {
1429  return ::parseAtomicUpdateOp(parser, result, true);
1430 }
1432  ::printAtomicUpdateOp(*this, p);
1433 }
1434 
1435 //===----------------------------------------------------------------------===//
1436 // spv.AtomicUMaxOp
1437 //===----------------------------------------------------------------------===//
1438 
1440  return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1441 }
1442 
1443 ParseResult spirv::AtomicUMaxOp::parse(OpAsmParser &parser,
1444  OperationState &result) {
1445  return ::parseAtomicUpdateOp(parser, result, true);
1446 }
1448  ::printAtomicUpdateOp(*this, p);
1449 }
1450 
1451 //===----------------------------------------------------------------------===//
1452 // spv.AtomicUMinOp
1453 //===----------------------------------------------------------------------===//
1454 
1456  return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1457 }
1458 
1459 ParseResult spirv::AtomicUMinOp::parse(OpAsmParser &parser,
1460  OperationState &result) {
1461  return ::parseAtomicUpdateOp(parser, result, true);
1462 }
1464  ::printAtomicUpdateOp(*this, p);
1465 }
1466 
1467 //===----------------------------------------------------------------------===//
1468 // spv.AtomicXorOp
1469 //===----------------------------------------------------------------------===//
1470 
1472  return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1473 }
1474 
1475 ParseResult spirv::AtomicXorOp::parse(OpAsmParser &parser,
1476  OperationState &result) {
1477  return ::parseAtomicUpdateOp(parser, result, true);
1478 }
1480  ::printAtomicUpdateOp(*this, p);
1481 }
1482 
1483 //===----------------------------------------------------------------------===//
1484 // spv.BitcastOp
1485 //===----------------------------------------------------------------------===//
1486 
1488  // TODO: The SPIR-V spec validation rules are different for different
1489  // versions.
1490  auto operandType = operand().getType();
1491  auto resultType = result().getType();
1492  if (operandType == resultType) {
1493  return emitError("result type must be different from operand type");
1494  }
1495  if (operandType.isa<spirv::PointerType>() &&
1496  !resultType.isa<spirv::PointerType>()) {
1497  return emitError(
1498  "unhandled bit cast conversion from pointer type to non-pointer type");
1499  }
1500  if (!operandType.isa<spirv::PointerType>() &&
1501  resultType.isa<spirv::PointerType>()) {
1502  return emitError(
1503  "unhandled bit cast conversion from non-pointer type to pointer type");
1504  }
1505  auto operandBitWidth = getBitWidth(operandType);
1506  auto resultBitWidth = getBitWidth(resultType);
1507  if (operandBitWidth != resultBitWidth) {
1508  return emitOpError("mismatch in result type bitwidth ")
1509  << resultBitWidth << " and operand type bitwidth "
1510  << operandBitWidth;
1511  }
1512  return success();
1513 }
1514 
1515 //===----------------------------------------------------------------------===//
1516 // spv.BranchOp
1517 //===----------------------------------------------------------------------===//
1518 
1519 SuccessorOperands spirv::BranchOp::getSuccessorOperands(unsigned index) {
1520  assert(index == 0 && "invalid successor index");
1521  return SuccessorOperands(0, targetOperandsMutable());
1522 }
1523 
1524 //===----------------------------------------------------------------------===//
1525 // spv.BranchConditionalOp
1526 //===----------------------------------------------------------------------===//
1527 
1529 spirv::BranchConditionalOp::getSuccessorOperands(unsigned index) {
1530  assert(index < 2 && "invalid successor index");
1531  return SuccessorOperands(index == kTrueIndex ? trueTargetOperandsMutable()
1532  : falseTargetOperandsMutable());
1533 }
1534 
1535 ParseResult spirv::BranchConditionalOp::parse(OpAsmParser &parser,
1536  OperationState &state) {
1537  auto &builder = parser.getBuilder();
1539  Block *dest;
1540 
1541  // Parse the condition.
1542  Type boolTy = builder.getI1Type();
1543  if (parser.parseOperand(condInfo) ||
1544  parser.resolveOperand(condInfo, boolTy, state.operands))
1545  return failure();
1546 
1547  // Parse the optional branch weights.
1548  if (succeeded(parser.parseOptionalLSquare())) {
1549  IntegerAttr trueWeight, falseWeight;
1550  NamedAttrList weights;
1551 
1552  auto i32Type = builder.getIntegerType(32);
1553  if (parser.parseAttribute(trueWeight, i32Type, "weight", weights) ||
1554  parser.parseComma() ||
1555  parser.parseAttribute(falseWeight, i32Type, "weight", weights) ||
1556  parser.parseRSquare())
1557  return failure();
1558 
1560  builder.getArrayAttr({trueWeight, falseWeight}));
1561  }
1562 
1563  // Parse the true branch.
1564  SmallVector<Value, 4> trueOperands;
1565  if (parser.parseComma() ||
1566  parser.parseSuccessorAndUseList(dest, trueOperands))
1567  return failure();
1568  state.addSuccessors(dest);
1569  state.addOperands(trueOperands);
1570 
1571  // Parse the false branch.
1572  SmallVector<Value, 4> falseOperands;
1573  if (parser.parseComma() ||
1574  parser.parseSuccessorAndUseList(dest, falseOperands))
1575  return failure();
1576  state.addSuccessors(dest);
1577  state.addOperands(falseOperands);
1578  state.addAttribute(
1579  spirv::BranchConditionalOp::getOperandSegmentSizeAttr(),
1580  builder.getI32VectorAttr({1, static_cast<int32_t>(trueOperands.size()),
1581  static_cast<int32_t>(falseOperands.size())}));
1582 
1583  return success();
1584 }
1585 
1587  printer << ' ' << condition();
1588 
1589  if (auto weights = branch_weights()) {
1590  printer << " [";
1591  llvm::interleaveComma(weights->getValue(), printer, [&](Attribute a) {
1592  printer << a.cast<IntegerAttr>().getInt();
1593  });
1594  printer << "]";
1595  }
1596 
1597  printer << ", ";
1598  printer.printSuccessorAndUseList(getTrueBlock(), getTrueBlockArguments());
1599  printer << ", ";
1600  printer.printSuccessorAndUseList(getFalseBlock(), getFalseBlockArguments());
1601 }
1602 
1604  if (auto weights = branch_weights()) {
1605  if (weights->getValue().size() != 2) {
1606  return emitOpError("must have exactly two branch weights");
1607  }
1608  if (llvm::all_of(*weights, [](Attribute attr) {
1609  return attr.cast<IntegerAttr>().getValue().isNullValue();
1610  }))
1611  return emitOpError("branch weights cannot both be zero");
1612  }
1613 
1614  return success();
1615 }
1616 
1617 //===----------------------------------------------------------------------===//
1618 // spv.CompositeConstruct
1619 //===----------------------------------------------------------------------===//
1620 
1621 ParseResult spirv::CompositeConstructOp::parse(OpAsmParser &parser,
1622  OperationState &state) {
1624  Type type;
1625  auto loc = parser.getCurrentLocation();
1626 
1627  if (parser.parseOperandList(operands) || parser.parseColonType(type)) {
1628  return failure();
1629  }
1630  auto cType = type.dyn_cast<spirv::CompositeType>();
1631  if (!cType) {
1632  return parser.emitError(
1633  loc, "result type must be a composite type, but provided ")
1634  << type;
1635  }
1636 
1637  if (cType.hasCompileTimeKnownNumElements() &&
1638  operands.size() != cType.getNumElements()) {
1639  return parser.emitError(loc, "has incorrect number of operands: expected ")
1640  << cType.getNumElements() << ", but provided " << operands.size();
1641  }
1642  // TODO: Add support for constructing a vector type from the vector operands.
1643  // According to the spec: "for constructing a vector, the operands may
1644  // also be vectors with the same component type as the Result Type component
1645  // type".
1646  SmallVector<Type, 4> elementTypes;
1647  elementTypes.reserve(operands.size());
1648  for (auto index : llvm::seq<uint32_t>(0, operands.size())) {
1649  elementTypes.push_back(cType.getElementType(index));
1650  }
1651  state.addTypes(type);
1652  return parser.resolveOperands(operands, elementTypes, loc, state.operands);
1653 }
1654 
1656  printer << " " << constituents() << " : " << getResult().getType();
1657 }
1658 
1660  auto cType = getType().cast<spirv::CompositeType>();
1661  operand_range constituents = this->constituents();
1662 
1663  if (cType.isa<spirv::CooperativeMatrixNVType>()) {
1664  if (constituents.size() != 1)
1665  return emitError("has incorrect number of operands: expected ")
1666  << "1, but provided " << constituents.size();
1667  } else if (constituents.size() != cType.getNumElements()) {
1668  return emitError("has incorrect number of operands: expected ")
1669  << cType.getNumElements() << ", but provided "
1670  << constituents.size();
1671  }
1672 
1673  for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
1674  if (constituents[index].getType() != cType.getElementType(index)) {
1675  return emitError("operand type mismatch: expected operand type ")
1676  << cType.getElementType(index) << ", but provided "
1677  << constituents[index].getType();
1678  }
1679  }
1680 
1681  return success();
1682 }
1683 
1684 //===----------------------------------------------------------------------===//
1685 // spv.CompositeExtractOp
1686 //===----------------------------------------------------------------------===//
1687 
1688 void spirv::CompositeExtractOp::build(OpBuilder &builder, OperationState &state,
1689  Value composite,
1690  ArrayRef<int32_t> indices) {
1691  auto indexAttr = builder.getI32ArrayAttr(indices);
1692  auto elementType =
1693  getElementType(composite.getType(), indexAttr, state.location);
1694  if (!elementType) {
1695  return;
1696  }
1697  build(builder, state, elementType, composite, indexAttr);
1698 }
1699 
1700 ParseResult spirv::CompositeExtractOp::parse(OpAsmParser &parser,
1701  OperationState &state) {
1702  OpAsmParser::UnresolvedOperand compositeInfo;
1703  Attribute indicesAttr;
1704  Type compositeType;
1705  SMLoc attrLocation;
1706 
1707  if (parser.parseOperand(compositeInfo) ||
1708  parser.getCurrentLocation(&attrLocation) ||
1709  parser.parseAttribute(indicesAttr, kIndicesAttrName, state.attributes) ||
1710  parser.parseColonType(compositeType) ||
1711  parser.resolveOperand(compositeInfo, compositeType, state.operands)) {
1712  return failure();
1713  }
1714 
1715  Type resultType =
1716  getElementType(compositeType, indicesAttr, parser, attrLocation);
1717  if (!resultType) {
1718  return failure();
1719  }
1720  state.addTypes(resultType);
1721  return success();
1722 }
1723 
1725  printer << ' ' << composite() << indices() << " : " << composite().getType();
1726 }
1727 
1729  auto indicesArrayAttr = indices().dyn_cast<ArrayAttr>();
1730  auto resultType =
1731  getElementType(composite().getType(), indicesArrayAttr, getLoc());
1732  if (!resultType)
1733  return failure();
1734 
1735  if (resultType != getType()) {
1736  return emitOpError("invalid result type: expected ")
1737  << resultType << " but provided " << getType();
1738  }
1739 
1740  return success();
1741 }
1742 
1743 //===----------------------------------------------------------------------===//
1744 // spv.CompositeInsert
1745 //===----------------------------------------------------------------------===//
1746 
1747 void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state,
1748  Value object, Value composite,
1749  ArrayRef<int32_t> indices) {
1750  auto indexAttr = builder.getI32ArrayAttr(indices);
1751  build(builder, state, composite.getType(), object, composite, indexAttr);
1752 }
1753 
1754 ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
1755  OperationState &state) {
1757  Type objectType, compositeType;
1758  Attribute indicesAttr;
1759  auto loc = parser.getCurrentLocation();
1760 
1761  return failure(
1762  parser.parseOperandList(operands, 2) ||
1763  parser.parseAttribute(indicesAttr, kIndicesAttrName, state.attributes) ||
1764  parser.parseColonType(objectType) ||
1765  parser.parseKeywordType("into", compositeType) ||
1766  parser.resolveOperands(operands, {objectType, compositeType}, loc,
1767  state.operands) ||
1768  parser.addTypesToList(compositeType, state.types));
1769 }
1770 
1772  auto indicesArrayAttr = indices().dyn_cast<ArrayAttr>();
1773  auto objectType =
1774  getElementType(composite().getType(), indicesArrayAttr, getLoc());
1775  if (!objectType)
1776  return failure();
1777 
1778  if (objectType != object().getType()) {
1779  return emitOpError("object operand type should be ")
1780  << objectType << ", but found " << object().getType();
1781  }
1782 
1783  if (composite().getType() != getType()) {
1784  return emitOpError("result type should be the same as "
1785  "the composite type, but found ")
1786  << composite().getType() << " vs " << getType();
1787  }
1788 
1789  return success();
1790 }
1791 
1793  printer << " " << object() << ", " << composite() << indices() << " : "
1794  << object().getType() << " into " << composite().getType();
1795 }
1796 
1797 //===----------------------------------------------------------------------===//
1798 // spv.Constant
1799 //===----------------------------------------------------------------------===//
1800 
1801 ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
1802  OperationState &state) {
1803  Attribute value;
1804  if (parser.parseAttribute(value, kValueAttrName, state.attributes))
1805  return failure();
1806 
1807  Type type = value.getType();
1808  if (type.isa<NoneType, TensorType>()) {
1809  if (parser.parseColonType(type))
1810  return failure();
1811  }
1812 
1813  return parser.addTypeToList(type, state.types);
1814 }
1815 
1816 void spirv::ConstantOp::print(OpAsmPrinter &printer) {
1817  printer << ' ' << value();
1818  if (getType().isa<spirv::ArrayType>())
1819  printer << " : " << getType();
1820 }
1821 
1822 static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
1823  Type opType) {
1824  auto valueType = value.getType();
1825 
1826  if (value.isa<IntegerAttr, FloatAttr>()) {
1827  if (valueType != opType)
1828  return op.emitOpError("result type (")
1829  << opType << ") does not match value type (" << valueType << ")";
1830  return success();
1831  }
1832  if (value.isa<DenseIntOrFPElementsAttr, SparseElementsAttr>()) {
1833  if (valueType == opType)
1834  return success();
1835  auto arrayType = opType.dyn_cast<spirv::ArrayType>();
1836  auto shapedType = valueType.dyn_cast<ShapedType>();
1837  if (!arrayType)
1838  return op.emitOpError("result or element type (")
1839  << opType << ") does not match value type (" << valueType
1840  << "), must be the same or spv.array";
1841 
1842  int numElements = arrayType.getNumElements();
1843  auto opElemType = arrayType.getElementType();
1844  while (auto t = opElemType.dyn_cast<spirv::ArrayType>()) {
1845  numElements *= t.getNumElements();
1846  opElemType = t.getElementType();
1847  }
1848  if (!opElemType.isIntOrFloat())
1849  return op.emitOpError("only support nested array result type");
1850 
1851  auto valueElemType = shapedType.getElementType();
1852  if (valueElemType != opElemType) {
1853  return op.emitOpError("result element type (")
1854  << opElemType << ") does not match value element type ("
1855  << valueElemType << ")";
1856  }
1857 
1858  if (numElements != shapedType.getNumElements()) {
1859  return op.emitOpError("result number of elements (")
1860  << numElements << ") does not match value number of elements ("
1861  << shapedType.getNumElements() << ")";
1862  }
1863  return success();
1864  }
1865  if (auto arrayAttr = value.dyn_cast<ArrayAttr>()) {
1866  auto arrayType = opType.dyn_cast<spirv::ArrayType>();
1867  if (!arrayType)
1868  return op.emitOpError("must have spv.array result type for array value");
1869  Type elemType = arrayType.getElementType();
1870  for (Attribute element : arrayAttr.getValue()) {
1871  // Verify array elements recursively.
1872  if (failed(verifyConstantType(op, element, elemType)))
1873  return failure();
1874  }
1875  return success();
1876  }
1877  return op.emitOpError("cannot have value of type ") << valueType;
1878 }
1879 
1881  // ODS already generates checks to make sure the result type is valid. We just
1882  // need to additionally check that the value's attribute type is consistent
1883  // with the result type.
1884  return verifyConstantType(*this, valueAttr(), getType());
1885 }
1886 
1887 bool spirv::ConstantOp::isBuildableWith(Type type) {
1888  // Must be valid SPIR-V type first.
1889  if (!type.isa<spirv::SPIRVType>())
1890  return false;
1891 
1892  if (isa<SPIRVDialect>(type.getDialect())) {
1893  // TODO: support constant struct
1894  return type.isa<spirv::ArrayType>();
1895  }
1896 
1897  return true;
1898 }
1899 
1900 spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
1901  OpBuilder &builder) {
1902  if (auto intType = type.dyn_cast<IntegerType>()) {
1903  unsigned width = intType.getWidth();
1904  if (width == 1)
1905  return builder.create<spirv::ConstantOp>(loc, type,
1906  builder.getBoolAttr(false));
1907  return builder.create<spirv::ConstantOp>(
1908  loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
1909  }
1910  if (auto floatType = type.dyn_cast<FloatType>()) {
1911  return builder.create<spirv::ConstantOp>(
1912  loc, type, builder.getFloatAttr(floatType, 0.0));
1913  }
1914  if (auto vectorType = type.dyn_cast<VectorType>()) {
1915  Type elemType = vectorType.getElementType();
1916  if (elemType.isa<IntegerType>()) {
1917  return builder.create<spirv::ConstantOp>(
1918  loc, type,
1920  IntegerAttr::get(elemType, 0.0).getValue()));
1921  }
1922  if (elemType.isa<FloatType>()) {
1923  return builder.create<spirv::ConstantOp>(
1924  loc, type,
1926  FloatAttr::get(elemType, 0.0).getValue()));
1927  }
1928  }
1929 
1930  llvm_unreachable("unimplemented types for ConstantOp::getZero()");
1931 }
1932 
1933 spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
1934  OpBuilder &builder) {
1935  if (auto intType = type.dyn_cast<IntegerType>()) {
1936  unsigned width = intType.getWidth();
1937  if (width == 1)
1938  return builder.create<spirv::ConstantOp>(loc, type,
1939  builder.getBoolAttr(true));
1940  return builder.create<spirv::ConstantOp>(
1941  loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
1942  }
1943  if (auto floatType = type.dyn_cast<FloatType>()) {
1944  return builder.create<spirv::ConstantOp>(
1945  loc, type, builder.getFloatAttr(floatType, 1.0));
1946  }
1947  if (auto vectorType = type.dyn_cast<VectorType>()) {
1948  Type elemType = vectorType.getElementType();
1949  if (elemType.isa<IntegerType>()) {
1950  return builder.create<spirv::ConstantOp>(
1951  loc, type,
1953  IntegerAttr::get(elemType, 1.0).getValue()));
1954  }
1955  if (elemType.isa<FloatType>()) {
1956  return builder.create<spirv::ConstantOp>(
1957  loc, type,
1959  FloatAttr::get(elemType, 1.0).getValue()));
1960  }
1961  }
1962 
1963  llvm_unreachable("unimplemented types for ConstantOp::getOne()");
1964 }
1965 
1966 void mlir::spirv::ConstantOp::getAsmResultNames(
1967  llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
1968  Type type = getType();
1969 
1970  SmallString<32> specialNameBuffer;
1971  llvm::raw_svector_ostream specialName(specialNameBuffer);
1972  specialName << "cst";
1973 
1974  IntegerType intTy = type.dyn_cast<IntegerType>();
1975 
1976  if (IntegerAttr intCst = value().dyn_cast<IntegerAttr>()) {
1977  if (intTy && intTy.getWidth() == 1) {
1978  return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
1979  }
1980 
1981  if (intTy.isSignless()) {
1982  specialName << intCst.getInt();
1983  } else {
1984  specialName << intCst.getSInt();
1985  }
1986  }
1987 
1988  if (intTy || type.isa<FloatType>()) {
1989  specialName << '_' << type;
1990  }
1991 
1992  if (auto vecType = type.dyn_cast<VectorType>()) {
1993  specialName << "_vec_";
1994  specialName << vecType.getDimSize(0);
1995 
1996  Type elementType = vecType.getElementType();
1997 
1998  if (elementType.isa<IntegerType>() || elementType.isa<FloatType>()) {
1999  specialName << "x" << elementType;
2000  }
2001  }
2002 
2003  setNameFn(getResult(), specialName.str());
2004 }
2005 
2006 void mlir::spirv::AddressOfOp::getAsmResultNames(
2007  llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
2008  SmallString<32> specialNameBuffer;
2009  llvm::raw_svector_ostream specialName(specialNameBuffer);
2010  specialName << variable() << "_addr";
2011  setNameFn(getResult(), specialName.str());
2012 }
2013 
2014 //===----------------------------------------------------------------------===//
2015 // spv.ControlBarrierOp
2016 //===----------------------------------------------------------------------===//
2017 
2019  return verifyMemorySemantics(getOperation(), memory_semantics());
2020 }
2021 
2022 //===----------------------------------------------------------------------===//
2023 // spv.ConvertFToSOp
2024 //===----------------------------------------------------------------------===//
2025 
2027  return verifyCastOp(*this, /*requireSameBitWidth=*/false,
2028  /*skipBitWidthCheck=*/true);
2029 }
2030 
2031 //===----------------------------------------------------------------------===//
2032 // spv.ConvertFToUOp
2033 //===----------------------------------------------------------------------===//
2034 
2036  return verifyCastOp(*this, /*requireSameBitWidth=*/false,
2037  /*skipBitWidthCheck=*/true);
2038 }
2039 
2040 //===----------------------------------------------------------------------===//
2041 // spv.ConvertSToFOp
2042 //===----------------------------------------------------------------------===//
2043 
2045  return verifyCastOp(*this, /*requireSameBitWidth=*/false,
2046  /*skipBitWidthCheck=*/true);
2047 }
2048 
2049 //===----------------------------------------------------------------------===//
2050 // spv.ConvertUToFOp
2051 //===----------------------------------------------------------------------===//
2052 
2054  return verifyCastOp(*this, /*requireSameBitWidth=*/false,
2055  /*skipBitWidthCheck=*/true);
2056 }
2057 
2058 //===----------------------------------------------------------------------===//
2059 // spv.EntryPoint
2060 //===----------------------------------------------------------------------===//
2061 
2062 void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state,
2063  spirv::ExecutionModel executionModel,
2064  spirv::FuncOp function,
2065  ArrayRef<Attribute> interfaceVars) {
2066  build(builder, state,
2067  spirv::ExecutionModelAttr::get(builder.getContext(), executionModel),
2068  SymbolRefAttr::get(function), builder.getArrayAttr(interfaceVars));
2069 }
2070 
2071 ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
2072  OperationState &state) {
2073  spirv::ExecutionModel execModel;
2075  SmallVector<Type, 0> idTypes;
2076  SmallVector<Attribute, 4> interfaceVars;
2077 
2078  FlatSymbolRefAttr fn;
2079  if (parseEnumStrAttr(execModel, parser, state) ||
2080  parser.parseAttribute(fn, Type(), kFnNameAttrName, state.attributes)) {
2081  return failure();
2082  }
2083 
2084  if (!parser.parseOptionalComma()) {
2085  // Parse the interface variables
2086  if (parser.parseCommaSeparatedList([&]() -> ParseResult {
2087  // The name of the interface variable attribute isn't important
2088  FlatSymbolRefAttr var;
2089  NamedAttrList attrs;
2090  if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
2091  return failure();
2092  interfaceVars.push_back(var);
2093  return success();
2094  }))
2095  return failure();
2096  }
2098  parser.getBuilder().getArrayAttr(interfaceVars));
2099  return success();
2100 }
2101 
2103  printer << " \"" << stringifyExecutionModel(execution_model()) << "\" ";
2104  printer.printSymbolName(fn());
2105  auto interfaceVars = interface().getValue();
2106  if (!interfaceVars.empty()) {
2107  printer << ", ";
2108  llvm::interleaveComma(interfaceVars, printer);
2109  }
2110 }
2111 
2113  // Checks for fn and interface symbol reference are done in spirv::ModuleOp
2114  // verification.
2115  return success();
2116 }
2117 
2118 //===----------------------------------------------------------------------===//
2119 // spv.ExecutionMode
2120 //===----------------------------------------------------------------------===//
2121 
2122 void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state,
2123  spirv::FuncOp function,
2124  spirv::ExecutionMode executionMode,
2125  ArrayRef<int32_t> params) {
2126  build(builder, state, SymbolRefAttr::get(function),
2127  spirv::ExecutionModeAttr::get(builder.getContext(), executionMode),
2128  builder.getI32ArrayAttr(params));
2129 }
2130 
2131 ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
2132  OperationState &state) {
2133  spirv::ExecutionMode execMode;
2134  Attribute fn;
2135  if (parser.parseAttribute(fn, kFnNameAttrName, state.attributes) ||
2136  parseEnumStrAttr(execMode, parser, state)) {
2137  return failure();
2138  }
2139 
2140  SmallVector<int32_t, 4> values;
2141  Type i32Type = parser.getBuilder().getIntegerType(32);
2142  while (!parser.parseOptionalComma()) {
2143  NamedAttrList attr;
2144  Attribute value;
2145  if (parser.parseAttribute(value, i32Type, "value", attr)) {
2146  return failure();
2147  }
2148  values.push_back(value.cast<IntegerAttr>().getInt());
2149  }
2151  parser.getBuilder().getI32ArrayAttr(values));
2152  return success();
2153 }
2154 
2156  printer << " ";
2157  printer.printSymbolName(fn());
2158  printer << " \"" << stringifyExecutionMode(execution_mode()) << "\"";
2159  auto values = this->values();
2160  if (values.empty())
2161  return;
2162  printer << ", ";
2163  llvm::interleaveComma(values, printer, [&](Attribute a) {
2164  printer << a.cast<IntegerAttr>().getInt();
2165  });
2166 }
2167 
2168 //===----------------------------------------------------------------------===//
2169 // spv.FConvertOp
2170 //===----------------------------------------------------------------------===//
2171 
2173  return verifyCastOp(*this, /*requireSameBitWidth=*/false);
2174 }
2175 
2176 //===----------------------------------------------------------------------===//
2177 // spv.SConvertOp
2178 //===----------------------------------------------------------------------===//
2179 
2181  return verifyCastOp(*this, /*requireSameBitWidth=*/false);
2182 }
2183 
2184 //===----------------------------------------------------------------------===//
2185 // spv.UConvertOp
2186 //===----------------------------------------------------------------------===//
2187 
2189  return verifyCastOp(*this, /*requireSameBitWidth=*/false);
2190 }
2191 
2192 //===----------------------------------------------------------------------===//
2193 // spv.func
2194 //===----------------------------------------------------------------------===//
2195 
2196 ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &state) {
2198  SmallVector<DictionaryAttr> resultAttrs;
2199  SmallVector<Type> resultTypes;
2200  auto &builder = parser.getBuilder();
2201 
2202  // Parse the name as a symbol.
2203  StringAttr nameAttr;
2204  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
2205  state.attributes))
2206  return failure();
2207 
2208  // Parse the function signature.
2209  bool isVariadic = false;
2211  parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
2212  resultAttrs))
2213  return failure();
2214 
2215  SmallVector<Type> argTypes;
2216  for (auto &arg : entryArgs)
2217  argTypes.push_back(arg.type);
2218  auto fnType = builder.getFunctionType(argTypes, resultTypes);
2220  TypeAttr::get(fnType));
2221 
2222  // Parse the optional function control keyword.
2223  spirv::FunctionControl fnControl;
2224  if (parseEnumStrAttr(fnControl, parser, state))
2225  return failure();
2226 
2227  // If additional attributes are present, parse them.
2228  if (parser.parseOptionalAttrDictWithKeyword(state.attributes))
2229  return failure();
2230 
2231  // Add the attributes to the function arguments.
2232  assert(resultAttrs.size() == resultTypes.size());
2233  function_interface_impl::addArgAndResultAttrs(builder, state, entryArgs,
2234  resultAttrs);
2235 
2236  // Parse the optional function body.
2237  auto *body = state.addRegion();
2238  OptionalParseResult result = parser.parseOptionalRegion(*body, entryArgs);
2239  return failure(result.hasValue() && failed(*result));
2240 }
2241 
2242 void spirv::FuncOp::print(OpAsmPrinter &printer) {
2243  // Print function name, signature, and control.
2244  printer << " ";
2245  printer.printSymbolName(sym_name());
2246  auto fnType = getFunctionType();
2248  printer, *this, fnType.getInputs(),
2249  /*isVariadic=*/false, fnType.getResults());
2250  printer << " \"" << spirv::stringifyFunctionControl(function_control())
2251  << "\"";
2253  printer, *this, fnType.getNumInputs(), fnType.getNumResults(),
2254  {spirv::attributeName<spirv::FunctionControl>()});
2255 
2256  // Print the body if this is not an external function.
2257  Region &body = this->body();
2258  if (!body.empty()) {
2259  printer << ' ';
2260  printer.printRegion(body, /*printEntryBlockArgs=*/false,
2261  /*printBlockTerminators=*/true);
2262  }
2263 }
2264 
2265 LogicalResult spirv::FuncOp::verifyType() {
2266  auto type = getFunctionTypeAttr().getValue();
2267  if (!type.isa<FunctionType>())
2268  return emitOpError("requires '" + getTypeAttrName() +
2269  "' attribute of function type");
2270  if (getFunctionType().getNumResults() > 1)
2271  return emitOpError("cannot have more than one result");
2272  return success();
2273 }
2274 
2275 LogicalResult spirv::FuncOp::verifyBody() {
2276  FunctionType fnType = getFunctionType();
2277 
2278  auto walkResult = walk([fnType](Operation *op) -> WalkResult {
2279  if (auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
2280  if (fnType.getNumResults() != 0)
2281  return retOp.emitOpError("cannot be used in functions returning value");
2282  } else if (auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
2283  if (fnType.getNumResults() != 1)
2284  return retOp.emitOpError(
2285  "returns 1 value but enclosing function requires ")
2286  << fnType.getNumResults() << " results";
2287 
2288  auto retOperandType = retOp.value().getType();
2289  auto fnResultType = fnType.getResult(0);
2290  if (retOperandType != fnResultType)
2291  return retOp.emitOpError(" return value's type (")
2292  << retOperandType << ") mismatch with function's result type ("
2293  << fnResultType << ")";
2294  }
2295  return WalkResult::advance();
2296  });
2297 
2298  // TODO: verify other bits like linkage type.
2299 
2300  return failure(walkResult.wasInterrupted());
2301 }
2302 
2303 void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
2304  StringRef name, FunctionType type,
2305  spirv::FunctionControl control,
2306  ArrayRef<NamedAttribute> attrs) {
2308  builder.getStringAttr(name));
2309  state.addAttribute(getTypeAttrName(), TypeAttr::get(type));
2310  state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
2311  builder.getI32IntegerAttr(static_cast<uint32_t>(control)));
2312  state.attributes.append(attrs.begin(), attrs.end());
2313  state.addRegion();
2314 }
2315 
2316 // CallableOpInterface
2317 Region *spirv::FuncOp::getCallableRegion() {
2318  return isExternal() ? nullptr : &body();
2319 }
2320 
2321 // CallableOpInterface
2322 ArrayRef<Type> spirv::FuncOp::getCallableResults() {
2323  return getFunctionType().getResults();
2324 }
2325 
2326 //===----------------------------------------------------------------------===//
2327 // spv.FunctionCall
2328 //===----------------------------------------------------------------------===//
2329 
2331  auto fnName = calleeAttr();
2332 
2333  auto funcOp = dyn_cast_or_null<spirv::FuncOp>(
2334  SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(), fnName));
2335  if (!funcOp) {
2336  return emitOpError("callee function '")
2337  << fnName.getValue() << "' not found in nearest symbol table";
2338  }
2339 
2340  auto functionType = funcOp.getFunctionType();
2341 
2342  if (getNumResults() > 1) {
2343  return emitOpError(
2344  "expected callee function to have 0 or 1 result, but provided ")
2345  << getNumResults();
2346  }
2347 
2348  if (functionType.getNumInputs() != getNumOperands()) {
2349  return emitOpError("has incorrect number of operands for callee: expected ")
2350  << functionType.getNumInputs() << ", but provided "
2351  << getNumOperands();
2352  }
2353 
2354  for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
2355  if (getOperand(i).getType() != functionType.getInput(i)) {
2356  return emitOpError("operand type mismatch: expected operand type ")
2357  << functionType.getInput(i) << ", but provided "
2358  << getOperand(i).getType() << " for operand number " << i;
2359  }
2360  }
2361 
2362  if (functionType.getNumResults() != getNumResults()) {
2363  return emitOpError(
2364  "has incorrect number of results has for callee: expected ")
2365  << functionType.getNumResults() << ", but provided "
2366  << getNumResults();
2367  }
2368 
2369  if (getNumResults() &&
2370  (getResult(0).getType() != functionType.getResult(0))) {
2371  return emitOpError("result type mismatch: expected ")
2372  << functionType.getResult(0) << ", but provided "
2373  << getResult(0).getType();
2374  }
2375 
2376  return success();
2377 }
2378 
2379 CallInterfaceCallable spirv::FunctionCallOp::getCallableForCallee() {
2380  return (*this)->getAttrOfType<SymbolRefAttr>(kCallee);
2381 }
2382 
2383 Operation::operand_range spirv::FunctionCallOp::getArgOperands() {
2384  return arguments();
2385 }
2386 
2387 //===----------------------------------------------------------------------===//
2388 // spv.GLSLFClampOp
2389 //===----------------------------------------------------------------------===//
2390 
2391 ParseResult spirv::GLSLFClampOp::parse(OpAsmParser &parser,
2392  OperationState &result) {
2393  return parseOneResultSameOperandTypeOp(parser, result);
2394 }
2396 
2397 //===----------------------------------------------------------------------===//
2398 // spv.GLSLUClampOp
2399 //===----------------------------------------------------------------------===//
2400 
2401 ParseResult spirv::GLSLUClampOp::parse(OpAsmParser &parser,
2402  OperationState &result) {
2403  return parseOneResultSameOperandTypeOp(parser, result);
2404 }
2406 
2407 //===----------------------------------------------------------------------===//
2408 // spv.GLSLSClampOp
2409 //===----------------------------------------------------------------------===//
2410 
2411 ParseResult spirv::GLSLSClampOp::parse(OpAsmParser &parser,
2412  OperationState &result) {
2413  return parseOneResultSameOperandTypeOp(parser, result);
2414 }
2416 
2417 //===----------------------------------------------------------------------===//
2418 // spv.GLSLFmaOp
2419 //===----------------------------------------------------------------------===//
2420 
2421 ParseResult spirv::GLSLFmaOp::parse(OpAsmParser &parser,
2422  OperationState &result) {
2423  return parseOneResultSameOperandTypeOp(parser, result);
2424 }
2426 
2427 //===----------------------------------------------------------------------===//
2428 // spv.GlobalVariable
2429 //===----------------------------------------------------------------------===//
2430 
2431 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
2432  Type type, StringRef name,
2433  unsigned descriptorSet, unsigned binding) {
2434  build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
2435  state.addAttribute(
2436  spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
2437  builder.getI32IntegerAttr(descriptorSet));
2438  state.addAttribute(
2439  spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
2440  builder.getI32IntegerAttr(binding));
2441 }
2442 
2443 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
2444  Type type, StringRef name,
2445  spirv::BuiltIn builtin) {
2446  build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
2447  state.addAttribute(
2448  spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
2449  builder.getStringAttr(spirv::stringifyBuiltIn(builtin)));
2450 }
2451 
2452 ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
2453  OperationState &state) {
2454  // Parse variable name.
2455  StringAttr nameAttr;
2456  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
2457  state.attributes)) {
2458  return failure();
2459  }
2460 
2461  // Parse optional initializer
2463  FlatSymbolRefAttr initSymbol;
2464  if (parser.parseLParen() ||
2465  parser.parseAttribute(initSymbol, Type(), kInitializerAttrName,
2466  state.attributes) ||
2467  parser.parseRParen())
2468  return failure();
2469  }
2470 
2471  if (parseVariableDecorations(parser, state)) {
2472  return failure();
2473  }
2474 
2475  Type type;
2476  auto loc = parser.getCurrentLocation();
2477  if (parser.parseColonType(type)) {
2478  return failure();
2479  }
2480  if (!type.isa<spirv::PointerType>()) {
2481  return parser.emitError(loc, "expected spv.ptr type");
2482  }
2483  state.addAttribute(kTypeAttrName, TypeAttr::get(type));
2484 
2485  return success();
2486 }
2487 
2489  SmallVector<StringRef, 4> elidedAttrs{
2490  spirv::attributeName<spirv::StorageClass>()};
2491 
2492  // Print variable name.
2493  printer << ' ';
2494  printer.printSymbolName(sym_name());
2495  elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
2496 
2497  // Print optional initializer
2498  if (auto initializer = this->initializer()) {
2499  printer << " " << kInitializerAttrName << '(';
2500  printer.printSymbolName(*initializer);
2501  printer << ')';
2502  elidedAttrs.push_back(kInitializerAttrName);
2503  }
2504 
2505  elidedAttrs.push_back(kTypeAttrName);
2506  printVariableDecorations(*this, printer, elidedAttrs);
2507  printer << " : " << type();
2508 }
2509 
2511  // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
2512  // object. It cannot be Generic. It must be the same as the Storage Class
2513  // operand of the Result Type."
2514  // Also, Function storage class is reserved by spv.Variable.
2515  auto storageClass = this->storageClass();
2516  if (storageClass == spirv::StorageClass::Generic ||
2517  storageClass == spirv::StorageClass::Function) {
2518  return emitOpError("storage class cannot be '")
2519  << stringifyStorageClass(storageClass) << "'";
2520  }
2521 
2522  if (auto init =
2523  (*this)->getAttrOfType<FlatSymbolRefAttr>(kInitializerAttrName)) {
2525  (*this)->getParentOp(), init.getAttr());
2526  // TODO: Currently only variable initialization with specialization
2527  // constants and other variables is supported. They could be normal
2528  // constants in the module scope as well.
2529  if (!initOp ||
2530  !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp>(initOp)) {
2531  return emitOpError("initializer must be result of a "
2532  "spv.SpecConstant or spv.GlobalVariable op");
2533  }
2534  }
2535 
2536  return success();
2537 }
2538 
2539 //===----------------------------------------------------------------------===//
2540 // spv.GroupBroadcast
2541 //===----------------------------------------------------------------------===//
2542 
2544  spirv::Scope scope = execution_scope();
2545  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2546  return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
2547 
2548  if (auto localIdTy = localid().getType().dyn_cast<VectorType>())
2549  if (localIdTy.getNumElements() != 2 && localIdTy.getNumElements() != 3)
2550  return emitOpError("localid is a vector and can be with only "
2551  " 2 or 3 components, actual number is ")
2552  << localIdTy.getNumElements();
2553 
2554  return success();
2555 }
2556 
2557 //===----------------------------------------------------------------------===//
2558 // spv.GroupNonUniformBallotOp
2559 //===----------------------------------------------------------------------===//
2560 
2562  spirv::Scope scope = execution_scope();
2563  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2564  return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
2565 
2566  return success();
2567 }
2568 
2569 //===----------------------------------------------------------------------===//
2570 // spv.GroupNonUniformBroadcast
2571 //===----------------------------------------------------------------------===//
2572 
2574  spirv::Scope scope = execution_scope();
2575  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2576  return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
2577 
2578  // SPIR-V spec: "Before version 1.5, Id must come from a
2579  // constant instruction."
2580  auto targetEnv = spirv::getDefaultTargetEnv(getContext());
2581  if (auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>())
2582  targetEnv = spirv::lookupTargetEnvOrDefault(spirvModule);
2583 
2584  if (targetEnv.getVersion() < spirv::Version::V_1_5) {
2585  auto *idOp = id().getDefiningOp();
2586  if (!idOp || !isa<spirv::ConstantOp, // for normal constant
2587  spirv::ReferenceOfOp>(idOp)) // for spec constant
2588  return emitOpError("id must be the result of a constant op");
2589  }
2590 
2591  return success();
2592 }
2593 
2594 //===----------------------------------------------------------------------===//
2595 // spv.SubgroupBlockReadINTEL
2596 //===----------------------------------------------------------------------===//
2597 
2598 ParseResult spirv::SubgroupBlockReadINTELOp::parse(OpAsmParser &parser,
2599  OperationState &state) {
2600  // Parse the storage class specification
2601  spirv::StorageClass storageClass;
2603  Type elementType;
2604  if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
2605  parser.parseColon() || parser.parseType(elementType)) {
2606  return failure();
2607  }
2608 
2609  auto ptrType = spirv::PointerType::get(elementType, storageClass);
2610  if (auto valVecTy = elementType.dyn_cast<VectorType>())
2611  ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
2612 
2613  if (parser.resolveOperand(ptrInfo, ptrType, state.operands)) {
2614  return failure();
2615  }
2616 
2617  state.addTypes(elementType);
2618  return success();
2619 }
2620 
2622  printer << " " << ptr() << " : " << getType();
2623 }
2624 
2626  if (failed(verifyBlockReadWritePtrAndValTypes(*this, ptr(), value())))
2627  return failure();
2628 
2629  return success();
2630 }
2631 
2632 //===----------------------------------------------------------------------===//
2633 // spv.SubgroupBlockWriteINTEL
2634 //===----------------------------------------------------------------------===//
2635 
2636 ParseResult spirv::SubgroupBlockWriteINTELOp::parse(OpAsmParser &parser,
2637  OperationState &state) {
2638  // Parse the storage class specification
2639  spirv::StorageClass storageClass;
2641  auto loc = parser.getCurrentLocation();
2642  Type elementType;
2643  if (parseEnumStrAttr(storageClass, parser) ||
2644  parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
2645  parser.parseType(elementType)) {
2646  return failure();
2647  }
2648 
2649  auto ptrType = spirv::PointerType::get(elementType, storageClass);
2650  if (auto valVecTy = elementType.dyn_cast<VectorType>())
2651  ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
2652 
2653  if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
2654  state.operands)) {
2655  return failure();
2656  }
2657  return success();
2658 }
2659 
2661  printer << " " << ptr() << ", " << value() << " : " << value().getType();
2662 }
2663 
2665  if (failed(verifyBlockReadWritePtrAndValTypes(*this, ptr(), value())))
2666  return failure();
2667 
2668  return success();
2669 }
2670 
2671 //===----------------------------------------------------------------------===//
2672 // spv.GroupNonUniformElectOp
2673 //===----------------------------------------------------------------------===//
2674 
2676  spirv::Scope scope = execution_scope();
2677  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2678  return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
2679 
2680  return success();
2681 }
2682 
2683 //===----------------------------------------------------------------------===//
2684 // spv.GroupNonUniformFAddOp
2685 //===----------------------------------------------------------------------===//
2686 
2688  return verifyGroupNonUniformArithmeticOp(*this);
2689 }
2690 
2691 ParseResult spirv::GroupNonUniformFAddOp::parse(OpAsmParser &parser,
2692  OperationState &result) {
2693  return parseGroupNonUniformArithmeticOp(parser, result);
2694 }
2697 }
2698 
2699 //===----------------------------------------------------------------------===//
2700 // spv.GroupNonUniformFMaxOp
2701 //===----------------------------------------------------------------------===//
2702 
2704  return verifyGroupNonUniformArithmeticOp(*this);
2705 }
2706 
2707 ParseResult spirv::GroupNonUniformFMaxOp::parse(OpAsmParser &parser,
2708  OperationState &result) {
2709  return parseGroupNonUniformArithmeticOp(parser, result);
2710 }
2713 }
2714 
2715 //===----------------------------------------------------------------------===//
2716 // spv.GroupNonUniformFMinOp
2717 //===----------------------------------------------------------------------===//
2718 
2720  return verifyGroupNonUniformArithmeticOp(*this);
2721 }
2722 
2723 ParseResult spirv::GroupNonUniformFMinOp::parse(OpAsmParser &parser,
2724  OperationState &result) {
2725  return parseGroupNonUniformArithmeticOp(parser, result);
2726 }
2729 }
2730 
2731 //===----------------------------------------------------------------------===//
2732 // spv.GroupNonUniformFMulOp
2733 //===----------------------------------------------------------------------===//
2734 
/// Verifies spv.GroupNonUniformFMul by deferring to the verification helper
/// that the group non-uniform arithmetic ops in this file have in common.
2736  return verifyGroupNonUniformArithmeticOp(*this);
2737 }
2738 
2739 ParseResult spirv::GroupNonUniformFMulOp::parse(OpAsmParser &parser,
2740  OperationState &result) {
2741  return parseGroupNonUniformArithmeticOp(parser, result);
2742 }
2745 }
2746 
2747 //===----------------------------------------------------------------------===//
2748 // spv.GroupNonUniformIAddOp
2749 //===----------------------------------------------------------------------===//
2750 
/// Verifies spv.GroupNonUniformIAdd by deferring to the verification helper
/// that the group non-uniform arithmetic ops in this file have in common.
2752  return verifyGroupNonUniformArithmeticOp(*this);
2753 }
2754 
2755 ParseResult spirv::GroupNonUniformIAddOp::parse(OpAsmParser &parser,
2756  OperationState &result) {
2757  return parseGroupNonUniformArithmeticOp(parser, result);
2758 }
2761 }
2762 
2763 //===----------------------------------------------------------------------===//
2764 // spv.GroupNonUniformIMulOp
2765 //===----------------------------------------------------------------------===//
2766 
/// Verifies spv.GroupNonUniformIMul by deferring to the verification helper
/// that the group non-uniform arithmetic ops in this file have in common.
2768  return verifyGroupNonUniformArithmeticOp(*this);
2769 }
2770 
2771 ParseResult spirv::GroupNonUniformIMulOp::parse(OpAsmParser &parser,
2772  OperationState &result) {
2773  return parseGroupNonUniformArithmeticOp(parser, result);
2774 }
2777 }
2778 
2779 //===----------------------------------------------------------------------===//
2780 // spv.GroupNonUniformSMaxOp
2781 //===----------------------------------------------------------------------===//
2782 
/// Verifies spv.GroupNonUniformSMax by deferring to the verification helper
/// that the group non-uniform arithmetic ops in this file have in common.
2784  return verifyGroupNonUniformArithmeticOp(*this);
2785 }
2786 
2787 ParseResult spirv::GroupNonUniformSMaxOp::parse(OpAsmParser &parser,
2788  OperationState &result) {
2789  return parseGroupNonUniformArithmeticOp(parser, result);
2790 }
2793 }
2794 
2795 //===----------------------------------------------------------------------===//
2796 // spv.GroupNonUniformSMinOp
2797 //===----------------------------------------------------------------------===//
2798 
/// Verifies spv.GroupNonUniformSMin by deferring to the verification helper
/// that the group non-uniform arithmetic ops in this file have in common.
2800  return verifyGroupNonUniformArithmeticOp(*this);
2801 }
2802 
2803 ParseResult spirv::GroupNonUniformSMinOp::parse(OpAsmParser &parser,
2804  OperationState &result) {
2805  return parseGroupNonUniformArithmeticOp(parser, result);
2806 }
2809 }
2810 
2811 //===----------------------------------------------------------------------===//
2812 // spv.GroupNonUniformUMaxOp
2813 //===----------------------------------------------------------------------===//
2814 
/// Verifies spv.GroupNonUniformUMax by deferring to the verification helper
/// that the group non-uniform arithmetic ops in this file have in common.
2816  return verifyGroupNonUniformArithmeticOp(*this);
2817 }
2818 
2819 ParseResult spirv::GroupNonUniformUMaxOp::parse(OpAsmParser &parser,
2820  OperationState &result) {
2821  return parseGroupNonUniformArithmeticOp(parser, result);
2822 }
2825 }
2826 
2827 //===----------------------------------------------------------------------===//
2828 // spv.GroupNonUniformUMinOp
2829 //===----------------------------------------------------------------------===//
2830 
/// Verifies spv.GroupNonUniformUMin by deferring to the verification helper
/// that the group non-uniform arithmetic ops in this file have in common.
2832  return verifyGroupNonUniformArithmeticOp(*this);
2833 }
2834 
2835 ParseResult spirv::GroupNonUniformUMinOp::parse(OpAsmParser &parser,
2836  OperationState &result) {
2837  return parseGroupNonUniformArithmeticOp(parser, result);
2838 }
2841 }
2842 
2843 //===----------------------------------------------------------------------===//
2844 // spv.ISubBorrowOp
2845 //===----------------------------------------------------------------------===//
2846 
/// Verifies spv.ISubBorrow: the result must be a struct with exactly two
/// members, and both operand types and both member types must be one and the
/// same type.
2848  auto resultType = getType().cast<spirv::StructType>();
2849  if (resultType.getNumElements() != 2)
2850  return emitOpError("expected result struct type containing two members");
2851 
  // Collect the two operand types and the two member types so that a single
  // splat check can require all four to be equal.
2852  SmallVector<Type, 4> types;
2853  types.push_back(operand1().getType());
2854  types.push_back(operand2().getType());
2855  types.push_back(resultType.getElementType(0));
2856  types.push_back(resultType.getElementType(1));
2857  if (!llvm::is_splat(types))
2858  return emitOpError(
2859  "expected all operand types and struct member types are the same");
2860 
2861  return success();
2862 }
2863 
/// Parses spv.ISubBorrow in the order: optional attribute dictionary, operand
/// list, `:`, result type. The result type must be a spv.struct with two
/// members; the type of its first member is used for both operands.
// NOTE(review): the declaration of `operands` is not visible in this listing.
2864 ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser,
2865  OperationState &state) {
2867  if (parser.parseOptionalAttrDict(state.attributes) ||
2868  parser.parseOperandList(operands) || parser.parseColon())
2869  return failure();
2870 
2871  Type resultType;
  // Remember where the type starts so diagnostics can point at it.
2872  auto loc = parser.getCurrentLocation();
2873  if (parser.parseType(resultType))
2874  return failure();
2875 
2876  auto structType = resultType.dyn_cast<spirv::StructType>();
2877  if (!structType || structType.getNumElements() != 2)
2878  return parser.emitError(loc, "expected spv.struct type with two members");
2879 
  // Both operands are resolved against the first struct member's type.
2880  SmallVector<Type, 2> operandTypes(2, structType.getElementType(0));
2881  if (parser.resolveOperands(operands, operandTypes, loc, state.operands))
2882  return failure();
2883 
2884  state.addTypes(resultType);
2885  return success();
2886 }
2887 
/// Prints spv.ISubBorrow as: a space, the optional attribute dictionary, the
/// operands, then ` : ` followed by the result type.
2889  printer << ' ';
2890  printer.printOptionalAttrDict((*this)->getAttrs());
2891  printer.printOperands((*this)->getOperands());
2892  printer << " : " << getType();
2893 }
2894 
2895 //===----------------------------------------------------------------------===//
2896 // spv.LoadOp
2897 //===----------------------------------------------------------------------===//
2898 
2899 void spirv::LoadOp::build(OpBuilder &builder, OperationState &state,
2900  Value basePtr, MemoryAccessAttr memoryAccess,
2901  IntegerAttr alignment) {
2902  auto ptrType = basePtr.getType().cast<spirv::PointerType>();
2903  build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess,
2904  alignment);
2905 }
2906 
/// Parses spv.Load in the order: storage class string, pointer operand,
/// memory access attributes, optional attribute dictionary, `:`, element
/// type. The pointer operand type is rebuilt from the element type and the
/// storage class, and the element type becomes the result type.
// NOTE(review): the declaration of `ptrInfo` is not visible in this listing.
2907 ParseResult spirv::LoadOp::parse(OpAsmParser &parser, OperationState &state) {
2908  // Parse the storage class specification
2909  spirv::StorageClass storageClass;
2911  Type elementType;
2912  if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
2913  parseMemoryAccessAttributes(parser, state) ||
2914  parser.parseOptionalAttrDict(state.attributes) || parser.parseColon() ||
2915  parser.parseType(elementType)) {
2916  return failure();
2917  }
2918 
  // The assembly only spells the pointee type; synthesize the pointer type.
2919  auto ptrType = spirv::PointerType::get(elementType, storageClass);
2920  if (parser.resolveOperand(ptrInfo, ptrType, state.operands)) {
2921  return failure();
2922  }
2923 
2924  state.addTypes(elementType);
2925  return success();
2926 }
2927 
2928 void spirv::LoadOp::print(OpAsmPrinter &printer) {
2929  SmallVector<StringRef, 4> elidedAttrs;
2930  StringRef sc = stringifyStorageClass(
2931  ptr().getType().cast<spirv::PointerType>().getStorageClass());
2932  printer << " \"" << sc << "\" " << ptr();
2933 
2934  printMemoryAccessAttribute(*this, printer, elidedAttrs);
2935 
2936  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
2937  printer << " : " << getType();
2938 }
2939 
/// Verifies spv.Load: first the pointer/value type pairing, then the memory
/// access attribute.
2941  // SPIR-V spec : "Result Type is the type of the loaded object. It must be a
2942  // type with fixed size; i.e., it cannot be, nor include, any
2943  // OpTypeRuntimeArray types."
2944  if (failed(verifyLoadStorePtrAndValTypes(*this, ptr(), value()))) {
2945  return failure();
2946  }
2947  return verifyMemoryAccessAttribute(*this);
2948 }
2949 
2950 //===----------------------------------------------------------------------===//
2951 // spv.mlir.loop
2952 //===----------------------------------------------------------------------===//
2953 
2954 void spirv::LoopOp::build(OpBuilder &builder, OperationState &state) {
2955  state.addAttribute("loop_control",
2956  builder.getI32IntegerAttr(
2957  static_cast<uint32_t>(spirv::LoopControl::None)));
2958  state.addRegion();
2959 }
2960 
2961 ParseResult spirv::LoopOp::parse(OpAsmParser &parser, OperationState &state) {
2962  if (parseControlAttribute<spirv::LoopControl>(parser, state))
2963  return failure();
2964  return parser.parseRegion(*state.addRegion(), /*arguments=*/{},
2965  /*argTypes=*/{});
2966 }
2967 
2968 void spirv::LoopOp::print(OpAsmPrinter &printer) {
2969  auto control = loop_control();
2970  if (control != spirv::LoopControl::None)
2971  printer << " control(" << spirv::stringifyLoopControl(control) << ")";
2972  printer << ' ';
2973  printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
2974  /*printBlockTerminators=*/true);
2975 }
2976 
2977 /// Returns true if the given `srcBlock` contains only one `spv.Branch` to the
2978 /// given `dstBlock`.
2979 static inline bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) {
2980  // Check that there is only one op in the `srcBlock`.
2981  if (!llvm::hasSingleElement(srcBlock))
2982  return false;
2983 
2984  auto branchOp = dyn_cast<spirv::BranchOp>(srcBlock.back());
2985  return branchOp && branchOp.getSuccessor() == &dstBlock;
2986 }
2987 
2988 LogicalResult spirv::LoopOp::verifyRegions() {
2989  auto *op = getOperation();
2990 
2991  // We need to verify that the blocks follow the following layout:
2992  //
2993  // +-------------+
2994  // | entry block |
2995  // +-------------+
2996  // |
2997  // v
2998  // +-------------+
2999  // | loop header | <-----+
3000  // +-------------+ |
3001  // |
3002  // ... |
3003  // \ | / |
3004  // v |
3005  // +---------------+ |
3006  // | loop continue | -----+
3007  // +---------------+
3008  //
3009  // ...
3010  // \ | /
3011  // v
3012  // +-------------+
3013  // | merge block |
3014  // +-------------+
3015 
3016  auto &region = op->getRegion(0);
3017  // Allow empty region as a degenerated case, which can come from
3018  // optimizations.
3019  if (region.empty())
3020  return success();
3021 
3022  // The last block is the merge block.
3023  Block &merge = region.back();
3024  if (!isMergeBlock(merge))
3025  return emitOpError(
3026  "last block must be the merge block with only one 'spv.mlir.merge' op");
3027 
3028  if (std::next(region.begin()) == region.end())
3029  return emitOpError(
3030  "must have an entry block branching to the loop header block");
3031  // The first block is the entry block.
3032  Block &entry = region.front();
3033 
3034  if (std::next(region.begin(), 2) == region.end())
3035  return emitOpError(
3036  "must have a loop header block branched from the entry block");
3037  // The second block is the loop header block.
3038  Block &header = *std::next(region.begin(), 1);
3039 
3040  if (!hasOneBranchOpTo(entry, header))
3041  return emitOpError(
3042  "entry block must only have one 'spv.Branch' op to the second block");
3043 
3044  if (std::next(region.begin(), 3) == region.end())
3045  return emitOpError(
3046  "requires a loop continue block branching to the loop header block");
3047  // The second to last block is the loop continue block.
3048  Block &cont = *std::prev(region.end(), 2);
3049 
3050  // Make sure that we have a branch from the loop continue block to the loop
3051  // header block.
3052  if (llvm::none_of(
3053  llvm::seq<unsigned>(0, cont.getNumSuccessors()),
3054  [&](unsigned index) { return cont.getSuccessor(index) == &header; }))
3055  return emitOpError("second to last block must be the loop continue "
3056  "block that branches to the loop header block");
3057 
3058  // Make sure that no other blocks (except the entry and loop continue block)
3059  // branches to the loop header block.
3060  for (auto &block : llvm::make_range(std::next(region.begin(), 2),
3061  std::prev(region.end(), 2))) {
3062  for (auto i : llvm::seq<unsigned>(0, block.getNumSuccessors())) {
3063  if (block.getSuccessor(i) == &header) {
3064  return emitOpError("can only have the entry and loop continue "
3065  "block branching to the loop header block");
3066  }
3067  }
3068  }
3069 
3070  return success();
3071 }
3072 
3073 Block *spirv::LoopOp::getEntryBlock() {
3074  assert(!body().empty() && "op region should not be empty!");
3075  return &body().front();
3076 }
3077 
3078 Block *spirv::LoopOp::getHeaderBlock() {
3079  assert(!body().empty() && "op region should not be empty!");
3080  // The second block is the loop header block.
3081  return &*std::next(body().begin());
3082 }
3083 
3084 Block *spirv::LoopOp::getContinueBlock() {
3085  assert(!body().empty() && "op region should not be empty!");
3086  // The second to last block is the loop continue block.
3087  return &*std::prev(body().end(), 2);
3088 }
3089 
3090 Block *spirv::LoopOp::getMergeBlock() {
3091  assert(!body().empty() && "op region should not be empty!");
3092  // The last block is the loop merge block.
3093  return &body().back();
3094 }
3095 
3096 void spirv::LoopOp::addEntryAndMergeBlock() {
3097  assert(body().empty() && "entry and merge block already exist");
3098  body().push_back(new Block());
3099  auto *mergeBlock = new Block();
3100  body().push_back(mergeBlock);
3101  OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
3102 
3103  // Add a spv.mlir.merge op into the merge block.
3104  builder.create<spirv::MergeOp>(getLoc());
3105 }
3106 
3107 //===----------------------------------------------------------------------===//
3108 // spv.MemoryBarrierOp
3109 //===----------------------------------------------------------------------===//
3110 
/// Verifies spv.MemoryBarrier by checking its memory semantics with the
/// memory-semantics verification helper.
3112  return verifyMemorySemantics(getOperation(), memory_semantics());
3113 }
3114 
3115 //===----------------------------------------------------------------------===//
3116 // spv.mlir.merge
3117 //===----------------------------------------------------------------------===//
3118 
/// Verifies spv.mlir.merge: its parent op must be a spv.mlir.selection or a
/// spv.mlir.loop, and it must be the terminator of the last block of the
/// parent region.
3120  auto *parentOp = (*this)->getParentOp();
3121  if (!parentOp || !isa<spirv::SelectionOp, spirv::LoopOp>(parentOp))
3122  return emitOpError(
3123  "expected parent op to be 'spv.mlir.selection' or 'spv.mlir.loop'");
3124 
3125  // TODO: This check should be done in `verifyRegions` of parent op.
3126  Block &parentLastBlock = (*this)->getParentRegion()->back();
3127  if (getOperation() != parentLastBlock.getTerminator())
3128  return emitOpError("can only be used in the last block of "
3129  "'spv.mlir.selection' or 'spv.mlir.loop'");
3130  return success();
3131 }
3132 
3133 //===----------------------------------------------------------------------===//
3134 // spv.module
3135 //===----------------------------------------------------------------------===//
3136 
/// Builds a spv.module with one region holding a single block. When `name`
/// is provided, a string attribute created from it is attached to the op.
// NOTE(review): the start of the statement that consumes the name attribute
// is not visible in this listing; confirm which attribute name it uses.
3137 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
3138  Optional<StringRef> name) {
  // The guard restores the builder's insertion point when this scope ends.
3139  OpBuilder::InsertionGuard guard(builder);
3140  builder.createBlock(state.addRegion());
3141  if (name) {
3143  builder.getStringAttr(*name));
3144  }
3145 }
3146 
/// Builds a spv.module with the given addressing and memory models, stored as
/// 32-bit integer attributes. The op gets one region holding a single block.
/// The VCE triple and the name are only attached when provided.
// NOTE(review): the start of the statement that consumes the name attribute
// is not visible in this listing; confirm which attribute name it uses.
3147 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
3148  spirv::AddressingModel addressingModel,
3149  spirv::MemoryModel memoryModel,
3150  Optional<VerCapExtAttr> vceTriple,
3151  Optional<StringRef> name) {
3152  state.addAttribute(
3153  "addressing_model",
3154  builder.getI32IntegerAttr(static_cast<int32_t>(addressingModel)));
3155  state.addAttribute("memory_model", builder.getI32IntegerAttr(
3156  static_cast<int32_t>(memoryModel)));
  // The guard restores the builder's insertion point when this scope ends.
3157  OpBuilder::InsertionGuard guard(builder);
3158  builder.createBlock(state.addRegion());
3159  if (vceTriple)
3160  state.addAttribute(getVCETripleAttrName(), *vceTriple);
3161  if (name)
3163  builder.getStringAttr(*name));
3164 }
3165 
/// Parses spv.module in the order: optional symbol name, addressing model
/// keyword, memory model keyword, optional `requires` followed by a VCE
/// triple attribute, optional attribute dictionary introduced by a keyword,
/// and the body region. A block is added when the parsed region is empty.
// NOTE(review): the arguments of the optional symbol name call are not
// visible in this listing.
3166 ParseResult spirv::ModuleOp::parse(OpAsmParser &parser, OperationState &state) {
3167  Region *body = state.addRegion();
3168 
3169  // If the name is present, parse it.
3170  StringAttr nameAttr;
3171  (void)parser.parseOptionalSymbolName(
3173 
3174  // Parse attributes
3175  spirv::AddressingModel addrModel;
3176  spirv::MemoryModel memoryModel;
3177  if (::parseEnumKeywordAttr(addrModel, parser, state) ||
3178  ::parseEnumKeywordAttr(memoryModel, parser, state))
3179  return failure();
3180 
  // The VCE triple is only parsed when the `requires` keyword is present.
3181  if (succeeded(parser.parseOptionalKeyword("requires"))) {
3182  spirv::VerCapExtAttr vceTriple;
3183  if (parser.parseAttribute(vceTriple,
3184  spirv::ModuleOp::getVCETripleAttrName(),
3185  state.attributes))
3186  return failure();
3187  }
3188 
3189  if (parser.parseOptionalAttrDictWithKeyword(state.attributes) ||
3190  parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
3191  return failure();
3192 
3193  // Make sure we have at least one block.
3194  if (body->empty())
3195  body->push_back(new Block());
3196 
3197  return success();
3198 }
3199 
/// Prints spv.module as: optional symbol name, addressing model, memory
/// model, optional `requires` clause with the VCE triple, the remaining
/// attributes as a keyword-introduced dictionary, and the body region.
// NOTE(review): the tail of the `elidedAttrs.assign` call is not visible in
// this listing, so the full list of initially elided attributes is unknown.
3200 void spirv::ModuleOp::print(OpAsmPrinter &printer) {
3201  if (Optional<StringRef> name = getName()) {
3202  printer << ' ';
3203  printer.printSymbolName(*name);
3204  }
3205 
3206  SmallVector<StringRef, 2> elidedAttrs;
3207 
3208  printer << " " << spirv::stringifyAddressingModel(addressing_model()) << " "
3209  << spirv::stringifyMemoryModel(memory_model());
  // Attributes already printed in custom form are kept out of the generic
  // attribute dictionary printed further below.
3210  auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
3211  auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
3212  elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
3214 
3215  if (Optional<spirv::VerCapExtAttr> triple = vce_triple()) {
3216  printer << " requires " << *triple;
3217  elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
3218  }
3219 
3220  printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs);
3221  printer << ' ';
3222  printer.printRegion(getRegion());
3223 }
3224 
/// Verifies the body of a spv.module:
///  * every op in the body belongs to the same dialect as the module;
///  * each spv.EntryPoint names a spv.func found in the module's symbol
///    table, lists only symbol references to spv.GlobalVariable ops in its
///    interface, and has a unique (function, execution model) pair;
///  * no spv.func is external, and every op nested in a spv.func belongs to
///    the module's dialect.
// NOTE(review): the type of `entryPoints` is declared on a line that is not
// visible in this listing; it is used below as a map keyed by
// (FuncOp, ExecutionModel).
3225 LogicalResult spirv::ModuleOp::verifyRegions() {
3226  Dialect *dialect = (*this)->getDialect();
3228  entryPoints;
3229  mlir::SymbolTable table(*this);
3230 
3231  for (auto &op : *getBody()) {
3232  if (op.getDialect() != dialect)
3233  return op.emitError("'spv.module' can only contain spv.* ops");
3234 
3235  // For EntryPoint op, check that the function and execution model is not
3236  // duplicated in EntryPointOps. Also verify that the interface specified
3237  // comes from globalVariables here to make this check cheaper.
3238  if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
3239  auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.fn());
3240  if (!funcOp) {
3241  return entryPointOp.emitError("function '")
3242  << entryPointOp.fn() << "' not found in 'spv.module'";
3243  }
3244  if (auto interface = entryPointOp.interface()) {
3245  for (Attribute varRef : interface) {
3246  auto varSymRef = varRef.dyn_cast<FlatSymbolRefAttr>();
3247  if (!varSymRef) {
3248  return entryPointOp.emitError(
3249  "expected symbol reference for interface "
3250  "specification instead of '")
3251  << varRef;
3252  }
3253  auto variableOp =
3254  table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
3255  if (!variableOp) {
3256  return entryPointOp.emitError("expected spv.GlobalVariable "
3257  "symbol reference instead of'")
3258  << varSymRef << "'";
3259  }
3260  }
3261  }
3262 
  // Record the pair; seeing the same pair again is a duplicate entry point.
3263  auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
3264  funcOp, entryPointOp.execution_model());
3265  auto entryPtIt = entryPoints.find(key);
3266  if (entryPtIt != entryPoints.end()) {
3267  return entryPointOp.emitError("duplicate of a previous EntryPointOp");
3268  }
3269  entryPoints[key] = entryPointOp;
3270  } else if (auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
3271  if (funcOp.isExternal())
3272  return op.emitError("'spv.module' cannot contain external functions");
3273 
3274  // TODO: move this check to spv.func.
3275  for (auto &block : funcOp)
3276  for (auto &op : block) {
3277  if (op.getDialect() != dialect)
3278  return op.emitError(
3279  "functions in 'spv.module' can only contain spv.* ops");
3280  }
3281  }
3282  }
3283 
3284  return success();
3285 }
3286 
3287 //===----------------------------------------------------------------------===//
3288 // spv.mlir.referenceof
3289 //===----------------------------------------------------------------------===//
3290 
/// Verifies spv.mlir.referenceof: the referenced symbol, looked up from the
/// parent op, must be a spv.SpecConstant or a spv.SpecConstantComposite, and
/// the result type must equal that constant's type.
3292  auto *specConstSym = SymbolTable::lookupNearestSymbolFrom(
3293  (*this)->getParentOp(), spec_constAttr());
3294  Type constType;
3295 
  // A plain spec constant takes its type from its default value.
3296  auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
3297  if (specConstOp)
3298  constType = specConstOp.default_value().getType();
3299 
  // A composite spec constant carries its type explicitly.
3300  auto specConstCompositeOp =
3301  dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
3302  if (specConstCompositeOp)
3303  constType = specConstCompositeOp.type();
3304 
3305  if (!specConstOp && !specConstCompositeOp)
3306  return emitOpError(
3307  "expected spv.SpecConstant or spv.SpecConstantComposite symbol");
3308 
3309  if (reference().getType() != constType)
3310  return emitOpError("result type mismatch with the referenced "
3311  "specialization constant's type");
3312 
3313  return success();
3314 }
3315 
3316 //===----------------------------------------------------------------------===//
3317 // spv.Return
3318 //===----------------------------------------------------------------------===//
3319 
/// Op-level verifier for spv.Return; it always succeeds.
3321  // Verification is performed in spv.func op.
3322  return success();
3323 }
3324 
3325 //===----------------------------------------------------------------------===//
3326 // spv.ReturnValue
3327 //===----------------------------------------------------------------------===//
3328 
/// Op-level verifier for spv.ReturnValue; it always succeeds.
3330  // Verification is performed in spv.func op.
3331  return success();
3332 }
3333 
3334 //===----------------------------------------------------------------------===//
3335 // spv.Select
3336 //===----------------------------------------------------------------------===//
3337 
/// Verifies spv.Select: when the condition is a vector, the result must also
/// be a vector with the same number of elements. Non-vector conditions are
/// not checked further here.
3339  if (auto conditionTy = condition().getType().dyn_cast<VectorType>()) {
3340  auto resultVectorTy = result().getType().dyn_cast<VectorType>();
3341  if (!resultVectorTy) {
3342  return emitOpError("result expected to be of vector type when "
3343  "condition is of vector type");
3344  }
3345  if (resultVectorTy.getNumElements() != conditionTy.getNumElements()) {
3346  return emitOpError("result should have the same number of elements as "
3347  "the condition when condition is of vector type");
3348  }
3349  }
3350  return success();
3351 }
3352 
3353 //===----------------------------------------------------------------------===//
3354 // spv.mlir.selection
3355 //===----------------------------------------------------------------------===//
3356 
3357 ParseResult spirv::SelectionOp::parse(OpAsmParser &parser,
3358  OperationState &state) {
3359  if (parseControlAttribute<spirv::SelectionControl>(parser, state))
3360  return failure();
3361  return parser.parseRegion(*state.addRegion(), /*arguments=*/{},
3362  /*argTypes=*/{});
3363 }
3364 
3365 void spirv::SelectionOp::print(OpAsmPrinter &printer) {
3366  auto control = selection_control();
3367  if (control != spirv::SelectionControl::None)
3368  printer << " control(" << spirv::stringifySelectionControl(control) << ")";
3369  printer << ' ';
3370  printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
3371  /*printBlockTerminators=*/true);
3372 }
3373 
3374 LogicalResult spirv::SelectionOp::verifyRegions() {
3375  auto *op = getOperation();
3376 
3377  // We need to verify that the blocks follow the following layout:
3378  //
3379  // +--------------+
3380  // | header block |
3381  // +--------------+
3382  // / | \
3383  // ...
3384  //
3385  //
3386  // +---------+ +---------+ +---------+
3387  // | case #0 | | case #1 | | case #2 | ...
3388  // +---------+ +---------+ +---------+
3389  //
3390  //
3391  // ...
3392  // \ | /
3393  // v
3394  // +-------------+
3395  // | merge block |
3396  // +-------------+
3397 
3398  auto &region = op->getRegion(0);
3399  // Allow empty region as a degenerated case, which can come from
3400  // optimizations.
3401  if (region.empty())
3402  return success();
3403 
3404  // The last block is the merge block.
3405  if (!isMergeBlock(region.back()))
3406  return emitOpError(
3407  "last block must be the merge block with only one 'spv.mlir.merge' op");
3408 
3409  if (std::next(region.begin()) == region.end())
3410  return emitOpError("must have a selection header block");
3411 
3412  return success();
3413 }
3414 
3415 Block *spirv::SelectionOp::getHeaderBlock() {
3416  assert(!body().empty() && "op region should not be empty!");
3417  // The first block is the loop header block.
3418  return &body().front();
3419 }
3420 
3421 Block *spirv::SelectionOp::getMergeBlock() {
3422  assert(!body().empty() && "op region should not be empty!");
3423  // The last block is the loop merge block.
3424  return &body().back();
3425 }
3426 
3427 void spirv::SelectionOp::addMergeBlock() {
3428  assert(body().empty() && "entry and merge block already exist");
3429  auto *mergeBlock = new Block();
3430  body().push_back(mergeBlock);
3431  OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
3432 
3433  // Add a spv.mlir.merge op into the merge block.
3434  builder.create<spirv::MergeOp>(getLoc());
3435 }
3436 
3437 spirv::SelectionOp spirv::SelectionOp::createIfThen(
3438  Location loc, Value condition,
3439  function_ref<void(OpBuilder &builder)> thenBody, OpBuilder &builder) {
3440  auto selectionOp =
3441  builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
3442 
3443  selectionOp.addMergeBlock();
3444  Block *mergeBlock = selectionOp.getMergeBlock();
3445  Block *thenBlock = nullptr;
3446 
3447  // Build the "then" block.
3448  {
3449  OpBuilder::InsertionGuard guard(builder);
3450  thenBlock = builder.createBlock(mergeBlock);
3451  thenBody(builder);
3452  builder.create<spirv::BranchOp>(loc, mergeBlock);
3453  }
3454 
3455  // Build the header block.
3456  {
3457  OpBuilder::InsertionGuard guard(builder);
3458  builder.createBlock(thenBlock);
3459  builder.create<spirv::BranchConditionalOp>(
3460  loc, condition, thenBlock,
3461  /*trueArguments=*/ArrayRef<Value>(), mergeBlock,
3462  /*falseArguments=*/ArrayRef<Value>());
3463  }
3464 
3465  return selectionOp;
3466 }
3467 
3468 //===----------------------------------------------------------------------===//
3469 // spv.SpecConstant
3470 //===----------------------------------------------------------------------===//
3471 
/// Parses spv.SpecConstant in the order: symbol name, optional parenthesized
/// spec id attribute, `=`, and the default value attribute.
// NOTE(review): the line that opens the optional spec_id block is not
// visible in this listing, so the condition guarding it cannot be confirmed.
3472 ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
3473  OperationState &state) {
3474  StringAttr nameAttr;
3475  Attribute valueAttr;
3476 
3477  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
3478  state.attributes))
3479  return failure();
3480 
3481  // Parse optional spec_id.
3483  IntegerAttr specIdAttr;
3484  if (parser.parseLParen() ||
3485  parser.parseAttribute(specIdAttr, kSpecIdAttrName, state.attributes) ||
3486  parser.parseRParen())
3487  return failure();
3488  }
3489 
3490  if (parser.parseEqual() ||
3491  parser.parseAttribute(valueAttr, kDefaultValueAttrName, state.attributes))
3492  return failure();
3493 
3494  return success();
3495 }
3496 
/// Prints spv.SpecConstant as: symbol name, then (only when the spec id
/// attribute is present) the spec id attribute name with the integer value in
/// parentheses, then ` = ` and the default value.
3498  printer << ' ';
3499  printer.printSymbolName(sym_name());
3500  if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
3501  printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')';
3502  printer << " = " << default_value();
3503 }
3504 
/// Verifies spv.SpecConstant: a present spec id must not be negative, and the
/// default value must be an integer or float attribute whose type is a
/// SPIR-V type.
3506  if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
3507  if (specID.getValue().isNegative())
3508  return emitOpError("SpecId cannot be negative");
3509 
3510  auto value = default_value();
3511  if (value.isa<IntegerAttr, FloatAttr>()) {
3512  // Make sure bitwidth is allowed.
3513  if (!value.getType().isa<spirv::SPIRVType>())
3514  return emitOpError("default value bitwidth disallowed");
3515  return success();
3516  }
3517  return emitOpError(
3518  "default value can only be a bool, integer, or float scalar");
3519 }
3520 
3521 //===----------------------------------------------------------------------===//
3522 // spv.StoreOp
3523 //===----------------------------------------------------------------------===//
3524 
/// Parses spv.Store in the order: storage class string, two operands (pointer
/// and value), memory access attributes, `:`, element type. The pointer
/// operand type is rebuilt from the element type and the storage class; the
/// value operand has the element type.
// NOTE(review): the declaration of `operandInfo` is not visible in this
// listing.
3525 ParseResult spirv::StoreOp::parse(OpAsmParser &parser, OperationState &state) {
3526  // Parse the storage class specification
3527  spirv::StorageClass storageClass;
3529  auto loc = parser.getCurrentLocation();
3530  Type elementType;
3531  if (parseEnumStrAttr(storageClass, parser) ||
3532  parser.parseOperandList(operandInfo, 2) ||
3533  parseMemoryAccessAttributes(parser, state) || parser.parseColon() ||
3534  parser.parseType(elementType)) {
3535  return failure();
3536  }
3537 
  // The assembly only spells the pointee type; synthesize the pointer type.
3538  auto ptrType = spirv::PointerType::get(elementType, storageClass);
3539  if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
3540  state.operands)) {
3541  return failure();
3542  }
3543  return success();
3544 }
3545 
3546 void spirv::StoreOp::print(OpAsmPrinter &printer) {
3547  SmallVector<StringRef, 4> elidedAttrs;
3548  StringRef sc = stringifyStorageClass(
3549  ptr().getType().cast<spirv::PointerType>().getStorageClass());
3550  printer << " \"" << sc << "\" " << ptr() << ", " << value();
3551 
3552  printMemoryAccessAttribute(*this, printer, elidedAttrs);
3553 
3554  printer << " : " << value().getType();
3555  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
3556 }
3557 
/// Verifies spv.Store: first the pointer/value type pairing, then the memory
/// access attribute.
3559  // SPIR-V spec : "Pointer is the pointer to store through. Its type must be an
3560  // OpTypePointer whose Type operand is the same as the type of Object."
3561  if (failed(verifyLoadStorePtrAndValTypes(*this, ptr(), value())))
3562  return failure();
3563  return verifyMemoryAccessAttribute(*this);
3564 }
3565 
3566 //===----------------------------------------------------------------------===//
3567 // spv.Unreachable
3568 //===----------------------------------------------------------------------===//
3569 
/// Verifies spv.Unreachable: it is rejected in an entry block. In any other
/// block it is currently accepted, with or without predecessors, because no
/// reachability analysis is performed yet.
3571  auto *block = (*this)->getBlock();
3572  // Fast track: if this is in entry block, it's invalid. Otherwise, if no
3573  // predecessors, it's valid.
3574  if (block->isEntryBlock())
3575  return emitOpError("cannot be used in reachable block");
3576  if (block->hasNoPredecessors())
3577  return success();
3578 
3579  // TODO: further verification needs to analyze reachability from
3580  // the entry block.
3581 
3582  return success();
3583 }
3584 
3585 //===----------------------------------------------------------------------===//
3586 // spv.Variable
3587 //===----------------------------------------------------------------------===//
3588 
/// Parses spv.Variable in the order: optional `init(<operand>)`, variable
/// decorations, `:`, result type. The result type must be a spv.ptr; the
/// initializer, when present, is resolved against its pointee type, and the
/// storage class attribute is derived from the pointer's storage class.
// NOTE(review): the declaration of `initInfo` is not visible in this listing;
// it is used below like an optional operand.
3589 ParseResult spirv::VariableOp::parse(OpAsmParser &parser,
3590  OperationState &state) {
3591  // Parse optional initializer
3593  if (succeeded(parser.parseOptionalKeyword("init"))) {
3594  initInfo = OpAsmParser::UnresolvedOperand();
3595  if (parser.parseLParen() || parser.parseOperand(*initInfo) ||
3596  parser.parseRParen())
3597  return failure();
3598  }
3599 
3600  if (parseVariableDecorations(parser, state)) {
3601  return failure();
3602  }
3603 
3604  // Parse result pointer type
3605  Type type;
3606  if (parser.parseColon())
3607  return failure();
  // Remember where the type starts so the diagnostic can point at it.
3608  auto loc = parser.getCurrentLocation();
3609  if (parser.parseType(type))
3610  return failure();
3611 
3612  auto ptrType = type.dyn_cast<spirv::PointerType>();
3613  if (!ptrType)
3614  return parser.emitError(loc, "expected spv.ptr type");
3615  state.addTypes(ptrType);
3616 
3617  // Resolve the initializer operand
3618  if (initInfo) {
3619  if (parser.resolveOperand(*initInfo, ptrType.getPointeeType(),
3620  state.operands))
3621  return failure();
3622  }
3623 
  // The storage class is not spelled separately; it comes from the pointer.
3624  auto attr = parser.getBuilder().getI32IntegerAttr(
3625  llvm::bit_cast<int32_t>(ptrType.getStorageClass()));
3626  state.addAttribute(spirv::attributeName<spirv::StorageClass>(), attr);
3627 
3628  return success();
3629 }
3630 
3631 void spirv::VariableOp::print(OpAsmPrinter &printer) {
3632  SmallVector<StringRef, 4> elidedAttrs{
3633  spirv::attributeName<spirv::StorageClass>()};
3634  // Print optional initializer
3635  if (getNumOperands() != 0)
3636  printer << " init(" << initializer() << ")";
3637 
3638  printVariableDecorations(*this, printer, elidedAttrs);
3639  printer << " : " << getType();
3640 }
3641 
/// Verifies spv.Variable:
///  * the storage class must be `Function` and must match the storage class
///    of the result pointer type;
///  * an initializer, when present, must be defined by a spv.Constant,
///    spv.mlir.referenceof, or spv.mlir.addressof op;
///  * the descriptor set, binding, and built-in decoration attributes are
///    rejected.
3643  // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
3644  // object. It cannot be Generic. It must be the same as the Storage Class
3645  // operand of the Result Type."
3646  if (storage_class() != spirv::StorageClass::Function) {
3647  return emitOpError(
3648  "can only be used to model function-level variables. Use "
3649  "spv.GlobalVariable for module-level variables.");
3650  }
3651 
3652  auto pointerType = pointer().getType().cast<spirv::PointerType>();
3653  if (storage_class() != pointerType.getStorageClass())
3654  return emitOpError(
3655  "storage class must match result pointer's storage class");
3656 
3657  if (getNumOperands() != 0) {
3658  // SPIR-V spec: "Initializer must be an <id> from a constant instruction or
3659  // a global (module scope) OpVariable instruction".
3660  auto *initOp = getOperand(0).getDefiningOp();
3661  if (!initOp || !isa<spirv::ConstantOp, // for normal constant
3662  spirv::ReferenceOfOp, // for spec constant
3663  spirv::AddressOfOp>(initOp))
3664  return emitOpError("initializer must be the result of a "
3665  "constant or spv.GlobalVariable op");
3666  }
3667 
3668  // TODO: generate these strings using ODS.
3669  auto *op = getOperation();
  // Attribute names are the snake_case form of the decoration names.
3670  auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
3671  stringifyDecoration(spirv::Decoration::DescriptorSet));
3672  auto bindingName = llvm::convertToSnakeFromCamelCase(
3673  stringifyDecoration(spirv::Decoration::Binding));
3674  auto builtInName = llvm::convertToSnakeFromCamelCase(
3675  stringifyDecoration(spirv::Decoration::BuiltIn));
3676 
3677  for (const auto &attr : {descriptorSetName, bindingName, builtInName}) {
3678  if (op->getAttr(attr))
3679  return emitOpError("cannot have '")
3680  << attr << "' attribute (only allowed in spv.GlobalVariable)";
3681  }
3682 
3683  return success();
3684 }
3685 
3686 //===----------------------------------------------------------------------===//
3687 // spv.VectorShuffle
3688 //===----------------------------------------------------------------------===//
3689 
3691  VectorType resultType = getType().cast<VectorType>();
3692 
3693  size_t numResultElements = resultType.getNumElements();
3694  if (numResultElements != components().size())
3695  return emitOpError("result type element count (")
3696  << numResultElements
3697  << ") mismatch with the number of component selectors ("
3698  << components().size() << ")";
3699 
3700  size_t totalSrcElements =
3701  vector1().getType().cast<VectorType>().getNumElements() +
3702  vector2().getType().cast<VectorType>().getNumElements();
3703 
3704  for (const auto &selector : components().getAsValueRange<IntegerAttr>()) {
3705  uint32_t index = selector.getZExtValue();
3706  if (index >= totalSrcElements &&
3707  index != std::numeric_limits<uint32_t>().max())
3708  return emitOpError("component selector ")
3709  << index << " out of range: expected to be in [0, "
3710  << totalSrcElements << ") or 0xffffffff";
3711  }
3712  return success();
3713 }
3714 
3715 //===----------------------------------------------------------------------===//
3716 // spv.CooperativeMatrixLoadNV
3717 //===----------------------------------------------------------------------===//
3718 
3719 ParseResult spirv::CooperativeMatrixLoadNVOp::parse(OpAsmParser &parser,
3720  OperationState &state) {
3722  Type strideType = parser.getBuilder().getIntegerType(32);
3723  Type columnMajorType = parser.getBuilder().getIntegerType(1);
3724  Type ptrType;
3725  Type elementType;
3726  if (parser.parseOperandList(operandInfo, 3) ||
3727  parseMemoryAccessAttributes(parser, state) || parser.parseColon() ||
3728  parser.parseType(ptrType) || parser.parseKeywordType("as", elementType)) {
3729  return failure();
3730  }
3731  if (parser.resolveOperands(operandInfo,
3732  {ptrType, strideType, columnMajorType},
3733  parser.getNameLoc(), state.operands)) {
3734  return failure();
3735  }
3736 
3737  state.addTypes(elementType);
3738  return success();
3739 }
3740 
3742  printer << " " << pointer() << ", " << stride() << ", " << columnmajor();
3743  // Print optional memory access attribute.
3744  if (auto memAccess = memory_access())
3745  printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
3746  printer << " : " << pointer().getType() << " as " << getType();
3747 }
3748 
3750  Type coopMatrix) {
3751  Type pointeeType = pointer.cast<spirv::PointerType>().getPointeeType();
3752  if (!pointeeType.isa<spirv::ScalarType>() && !pointeeType.isa<VectorType>())
3753  return op->emitError(
3754  "Pointer must point to a scalar or vector type but provided ")
3755  << pointeeType;
3756  spirv::StorageClass storage =
3757  pointer.cast<spirv::PointerType>().getStorageClass();
3758  if (storage != spirv::StorageClass::Workgroup &&
3759  storage != spirv::StorageClass::StorageBuffer &&
3760  storage != spirv::StorageClass::PhysicalStorageBuffer)
3761  return op->emitError(
3762  "Pointer storage class must be Workgroup, StorageBuffer or "
3763  "PhysicalStorageBufferEXT but provided ")
3764  << stringifyStorageClass(storage);
3765  return success();
3766 }
3767 
3769  return verifyPointerAndCoopMatrixType(*this, pointer().getType(),
3770  result().getType());
3771 }
3772 
3773 //===----------------------------------------------------------------------===//
3774 // spv.CooperativeMatrixStoreNV
3775 //===----------------------------------------------------------------------===//
3776 
3777 ParseResult spirv::CooperativeMatrixStoreNVOp::parse(OpAsmParser &parser,
3778  OperationState &state) {
3780  Type strideType = parser.getBuilder().getIntegerType(32);
3781  Type columnMajorType = parser.getBuilder().getIntegerType(1);
3782  Type ptrType;
3783  Type elementType;
3784  if (parser.parseOperandList(operandInfo, 4) ||
3785  parseMemoryAccessAttributes(parser, state) || parser.parseColon() ||
3786  parser.parseType(ptrType) || parser.parseComma() ||
3787  parser.parseType(elementType)) {
3788  return failure();
3789  }
3790  if (parser.resolveOperands(
3791  operandInfo, {ptrType, elementType, strideType, columnMajorType},
3792  parser.getNameLoc(), state.operands)) {
3793  return failure();
3794  }
3795 
3796  return success();
3797 }
3798 
3800  printer << " " << pointer() << ", " << object() << ", " << stride() << ", "
3801  << columnmajor();
3802  // Print optional memory access attribute.
3803  if (auto memAccess = memory_access())
3804  printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
3805  printer << " : " << pointer().getType() << ", " << getOperand(1).getType();
3806 }
3807 
3809  return verifyPointerAndCoopMatrixType(*this, pointer().getType(),
3810  object().getType());
3811 }
3812 
3813 //===----------------------------------------------------------------------===//
3814 // spv.CooperativeMatrixMulAddNV
3815 //===----------------------------------------------------------------------===//
3816 
3817 static LogicalResult
3818 verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) {
3819  if (op.c().getType() != op.result().getType())
3820  return op.emitOpError("result and third operand must have the same type");
3821  auto typeA = op.a().getType().cast<spirv::CooperativeMatrixNVType>();
3822  auto typeB = op.b().getType().cast<spirv::CooperativeMatrixNVType>();
3823  auto typeC = op.c().getType().cast<spirv::CooperativeMatrixNVType>();
3824  auto typeR = op.result().getType().cast<spirv::CooperativeMatrixNVType>();
3825  if (typeA.getRows() != typeR.getRows() ||
3826  typeA.getColumns() != typeB.getRows() ||
3827  typeB.getColumns() != typeR.getColumns())
3828  return op.emitOpError("matrix size must match");
3829  if (typeR.getScope() != typeA.getScope() ||
3830  typeR.getScope() != typeB.getScope() ||
3831  typeR.getScope() != typeC.getScope())
3832  return op.emitOpError("matrix scope must match");
3833  if (typeA.getElementType() != typeB.getElementType() ||
3834  typeR.getElementType() != typeC.getElementType())
3835  return op.emitOpError("matrix element type must match");
3836  return success();
3837 }
3838 
3840  return verifyCoopMatrixMulAdd(*this);
3841 }
3842 
3843 //===----------------------------------------------------------------------===//
3844 // spv.MatrixTimesScalar
3845 //===----------------------------------------------------------------------===//
3846 
3848  // We already checked that result and matrix are both of matrix type in the
3849  // auto-generated verify method.
3850 
3851  auto inputMatrix = matrix().getType().cast<spirv::MatrixType>();
3852  auto resultMatrix = result().getType().cast<spirv::MatrixType>();
3853 
3854  // Check that the scalar type is the same as the matrix element type.
3855  if (scalar().getType() != inputMatrix.getElementType())
3856  return emitError("input matrix components' type and scaling value must "
3857  "have the same type");
3858 
3859  // Note that the next three checks could be done using the AllTypesMatch
3860  // trait in the Op definition file but it generates a vague error message.
3861 
3862  // Check that the input and result matrices have the same columns' count
3863  if (inputMatrix.getNumColumns() != resultMatrix.getNumColumns())
3864  return emitError("input and result matrices must have the same "
3865  "number of columns");
3866 
3867  // Check that the input and result matrices' have the same rows count
3868  if (inputMatrix.getNumRows() != resultMatrix.getNumRows())
3869  return emitError("input and result matrices' columns must have "
3870  "the same size");
3871 
3872  // Check that the input and result matrices' have the same component type
3873  if (inputMatrix.getElementType() != resultMatrix.getElementType())
3874  return emitError("input and result matrices' columns must have "
3875  "the same component type");
3876 
3877  return success();
3878 }
3879 
3880 //===----------------------------------------------------------------------===//
3881 // spv.CopyMemory
3882 //===----------------------------------------------------------------------===//
3883 
3885  printer << ' ';
3886 
3887  StringRef targetStorageClass = stringifyStorageClass(
3888  target().getType().cast<spirv::PointerType>().getStorageClass());
3889  printer << " \"" << targetStorageClass << "\" " << target() << ", ";
3890 
3891  StringRef sourceStorageClass = stringifyStorageClass(
3892  source().getType().cast<spirv::PointerType>().getStorageClass());
3893  printer << " \"" << sourceStorageClass << "\" " << source();
3894 
3895  SmallVector<StringRef, 4> elidedAttrs;
3896  printMemoryAccessAttribute(*this, printer, elidedAttrs);
3897  printSourceMemoryAccessAttribute(*this, printer, elidedAttrs,
3898  source_memory_access(), source_alignment());
3899 
3900  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
3901 
3902  Type pointeeType =
3903  target().getType().cast<spirv::PointerType>().getPointeeType();
3904  printer << " : " << pointeeType;
3905 }
3906 
3907 ParseResult spirv::CopyMemoryOp::parse(OpAsmParser &parser,
3908  OperationState &state) {
3909  spirv::StorageClass targetStorageClass;
3910  OpAsmParser::UnresolvedOperand targetPtrInfo;
3911 
3912  spirv::StorageClass sourceStorageClass;
3913  OpAsmParser::UnresolvedOperand sourcePtrInfo;
3914 
3915  Type elementType;
3916 
3917  if (parseEnumStrAttr(targetStorageClass, parser) ||
3918  parser.parseOperand(targetPtrInfo) || parser.parseComma() ||
3919  parseEnumStrAttr(sourceStorageClass, parser) ||
3920  parser.parseOperand(sourcePtrInfo) ||
3921  parseMemoryAccessAttributes(parser, state)) {
3922  return failure();
3923  }
3924 
3925  if (!parser.parseOptionalComma()) {
3926  // Parse 2nd memory access attributes.
3927  if (parseSourceMemoryAccessAttributes(parser, state)) {
3928  return failure();
3929  }
3930  }
3931 
3932  if (parser.parseColon() || parser.parseType(elementType))
3933  return failure();
3934 
3935  if (parser.parseOptionalAttrDict(state.attributes))
3936  return failure();
3937 
3938  auto targetPtrType = spirv::PointerType::get(elementType, targetStorageClass);
3939  auto sourcePtrType = spirv::PointerType::get(elementType, sourceStorageClass);
3940 
3941  if (parser.resolveOperand(targetPtrInfo, targetPtrType, state.operands) ||
3942  parser.resolveOperand(sourcePtrInfo, sourcePtrType, state.operands)) {
3943  return failure();
3944  }
3945 
3946  return success();
3947 }
3948 
3950  Type targetType =
3951  target().getType().cast<spirv::PointerType>().getPointeeType();
3952 
3953  Type sourceType =
3954  source().getType().cast<spirv::PointerType>().getPointeeType();
3955 
3956  if (targetType != sourceType)
3957  return emitOpError("both operands must be pointers to the same type");
3958 
3959  if (failed(verifyMemoryAccessAttribute(*this)))
3960  return failure();
3961 
3962  // TODO - According to the spec:
3963  //
3964  // If two masks are present, the first applies to Target and cannot include
3965  // MakePointerVisible, and the second applies to Source and cannot include
3966  // MakePointerAvailable.
3967  //
3968  // Add such verification here.
3969 
3970  return verifySourceMemoryAccessAttribute(*this);
3971 }
3972 
3973 //===----------------------------------------------------------------------===//
3974 // spv.Transpose
3975 //===----------------------------------------------------------------------===//
3976 
3978  auto inputMatrix = matrix().getType().cast<spirv::MatrixType>();
3979  auto resultMatrix = result().getType().cast<spirv::MatrixType>();
3980 
3981  // Verify that the input and output matrices have correct shapes.
3982  if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
3983  return emitError("input matrix rows count must be equal to "
3984  "output matrix columns count");
3985 
3986  if (inputMatrix.getNumColumns() != resultMatrix.getNumRows())
3987  return emitError("input matrix columns count must be equal to "
3988  "output matrix rows count");
3989 
3990  // Verify that the input and output matrices have the same component type
3991  if (inputMatrix.getElementType() != resultMatrix.getElementType())
3992  return emitError("input and output matrices must have the same "
3993  "component type");
3994 
3995  return success();
3996 }
3997 
3998 //===----------------------------------------------------------------------===//
3999 // spv.MatrixTimesMatrix
4000 //===----------------------------------------------------------------------===//
4001 
4003  auto leftMatrix = leftmatrix().getType().cast<spirv::MatrixType>();
4004  auto rightMatrix = rightmatrix().getType().cast<spirv::MatrixType>();
4005  auto resultMatrix = result().getType().cast<spirv::MatrixType>();
4006 
4007  // left matrix columns' count and right matrix rows' count must be equal
4008  if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
4009  return emitError("left matrix columns' count must be equal to "
4010  "the right matrix rows' count");
4011 
4012  // right and result matrices columns' count must be the same
4013  if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns())
4014  return emitError(
4015  "right and result matrices must have equal columns' count");
4016 
4017  // right and result matrices component type must be the same
4018  if (rightMatrix.getElementType() != resultMatrix.getElementType())
4019  return emitError("right and result matrices' component type must"
4020  " be the same");
4021 
4022  // left and result matrices component type must be the same
4023  if (leftMatrix.getElementType() != resultMatrix.getElementType())
4024  return emitError("left and result matrices' component type"
4025  " must be the same");
4026 
4027  // left and result matrices rows count must be the same
4028  if (leftMatrix.getNumRows() != resultMatrix.getNumRows())
4029  return emitError("left and result matrices must have equal rows' count");
4030 
4031  return success();
4032 }
4033 
4034 //===----------------------------------------------------------------------===//
4035 // spv.SpecConstantComposite
4036 //===----------------------------------------------------------------------===//
4037 
4038 ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser,
4039  OperationState &state) {
4040 
4041  StringAttr compositeName;
4042  if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
4043  state.attributes))
4044  return failure();
4045 
4046  if (parser.parseLParen())
4047  return failure();
4048 
4049  SmallVector<Attribute, 4> constituents;
4050 
4051  do {
4052  // The name of the constituent attribute isn't important
4053  const char *attrName = "spec_const";
4054  FlatSymbolRefAttr specConstRef;
4055  NamedAttrList attrs;
4056 
4057  if (parser.parseAttribute(specConstRef, Type(), attrName, attrs))
4058  return failure();
4059 
4060  constituents.push_back(specConstRef);
4061  } while (!parser.parseOptionalComma());
4062 
4063  if (parser.parseRParen())
4064  return failure();
4065 
4067  parser.getBuilder().getArrayAttr(constituents));
4068 
4069  Type type;
4070  if (parser.parseColonType(type))
4071  return failure();
4072 
4073  state.addAttribute(kTypeAttrName, TypeAttr::get(type));
4074 
4075  return success();
4076 }
4077 
4079  printer << " ";
4080  printer.printSymbolName(sym_name());
4081  printer << " (";
4082  auto constituents = this->constituents().getValue();
4083 
4084  if (!constituents.empty())
4085  llvm::interleaveComma(constituents, printer);
4086 
4087  printer << ") : " << type();
4088 }
4089 
4091  auto cType = type().dyn_cast<spirv::CompositeType>();
4092  auto constituents = this->constituents().getValue();
4093 
4094  if (!cType)
4095  return emitError("result type must be a composite type, but provided ")
4096  << type();
4097 
4098  if (cType.isa<spirv::CooperativeMatrixNVType>())
4099  return emitError("unsupported composite type ") << cType;
4100  if (constituents.size() != cType.getNumElements())
4101  return emitError("has incorrect number of operands: expected ")
4102  << cType.getNumElements() << ", but provided "
4103  << constituents.size();
4104 
4105  for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
4106  auto constituent = constituents[index].cast<FlatSymbolRefAttr>();
4107 
4108  auto constituentSpecConstOp =
4109  dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom(
4110  (*this)->getParentOp(), constituent.getAttr()));
4111 
4112  if (constituentSpecConstOp.default_value().getType() !=
4113  cType.getElementType(index))
4114  return emitError("has incorrect types of operands: expected ")
4115  << cType.getElementType(index) << ", but provided "
4116  << constituentSpecConstOp.default_value().getType();
4117  }
4118 
4119  return success();
4120 }
4121 
4122 //===----------------------------------------------------------------------===//
4123 // spv.SpecConstantOperation
4124 //===----------------------------------------------------------------------===//
4125 
4126 ParseResult spirv::SpecConstantOperationOp::parse(OpAsmParser &parser,
4127  OperationState &state) {
4128  Region *body = state.addRegion();
4129 
4130  if (parser.parseKeyword("wraps"))
4131  return failure();
4132 
4133  body->push_back(new Block);
4134  Block &block = body->back();
4135  Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
4136 
4137  if (!wrappedOp)
4138  return failure();
4139 
4140  OpBuilder builder(parser.getContext());
4141  builder.setInsertionPointToEnd(&block);
4142  builder.create<spirv::YieldOp>(wrappedOp->getLoc(), wrappedOp->getResult(0));
4143  state.location = wrappedOp->getLoc();
4144 
4145  state.addTypes(wrappedOp->getResult(0).getType());
4146 
4147  if (parser.parseOptionalAttrDict(state.attributes))
4148  return failure();
4149 
4150  return success();
4151 }
4152 
4154  printer << " wraps ";
4155  printer.printGenericOp(&body().front().front());
4156 }
4157 
4158 LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
4159  Block &block = getRegion().getBlocks().front();
4160 
4161  if (block.getOperations().size() != 2)
4162  return emitOpError("expected exactly 2 nested ops");
4163 
4164  Operation &enclosedOp = block.getOperations().front();
4165 
4167  return emitOpError("invalid enclosed op");
4168 
4169  for (auto operand : enclosedOp.getOperands())
4170  if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
4171  spirv::SpecConstantOperationOp>(operand.getDefiningOp()))
4172  return emitOpError(
4173  "invalid operand, must be defined by a constant operation");
4174 
4175  return success();
4176 }
4177 
4178 //===----------------------------------------------------------------------===//
4179 // spv.GLSL.FrexpStruct
4180 //===----------------------------------------------------------------------===//
4181 
4183  spirv::StructType structTy = result().getType().dyn_cast<spirv::StructType>();
4184 
4185  if (structTy.getNumElements() != 2)
4186  return emitError("result type must be a struct type with two memebers");
4187 
4188  Type significandTy = structTy.getElementType(0);
4189  Type exponentTy = structTy.getElementType(1);
4190  VectorType exponentVecTy = exponentTy.dyn_cast<VectorType>();
4191  IntegerType exponentIntTy = exponentTy.dyn_cast<IntegerType>();
4192 
4193  Type operandTy = operand().getType();
4194  VectorType operandVecTy = operandTy.dyn_cast<VectorType>();
4195  FloatType operandFTy = operandTy.dyn_cast<FloatType>();
4196 
4197  if (significandTy != operandTy)
4198  return emitError("member zero of the resulting struct type must be the "
4199  "same type as the operand");
4200 
4201  if (exponentVecTy) {
4202  IntegerType componentIntTy =
4203  exponentVecTy.getElementType().dyn_cast<IntegerType>();
4204  if (!componentIntTy || componentIntTy.getWidth() != 32)
4205  return emitError("member one of the resulting struct type must"
4206  "be a scalar or vector of 32 bit integer type");
4207  } else if (!exponentIntTy || exponentIntTy.getWidth() != 32) {
4208  return emitError("member one of the resulting struct type "
4209  "must be a scalar or vector of 32 bit integer type");
4210  }
4211 
4212  // Check that the two member types have the same number of components
4213  if (operandVecTy && exponentVecTy &&
4214  (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
4215  return success();
4216 
4217  if (operandFTy && exponentIntTy)
4218  return success();
4219 
4220  return emitError("member one of the resulting struct type must have the same "
4221  "number of components as the operand type");
4222 }
4223 
4224 //===----------------------------------------------------------------------===//
4225 // spv.GLSL.Ldexp
4226 //===----------------------------------------------------------------------===//
4227 
4229  Type significandType = x().getType();
4230  Type exponentType = exp().getType();
4231 
4232  if (significandType.isa<FloatType>() != exponentType.isa<IntegerType>())
4233  return emitOpError("operands must both be scalars or vectors");
4234 
4235  auto getNumElements = [](Type type) -> unsigned {
4236  if (auto vectorType = type.dyn_cast<VectorType>())
4237  return vectorType.getNumElements();
4238  return 1;
4239  };
4240 
4241  if (getNumElements(significandType) != getNumElements(exponentType))
4242  return emitOpError("operands must have the same number of elements");
4243 
4244  return success();
4245 }
4246 
4247 //===----------------------------------------------------------------------===//
4248 // spv.ImageDrefGather
4249 //===----------------------------------------------------------------------===//
4250 
4252  VectorType resultType = result().getType().cast<VectorType>();
4253  auto sampledImageType =
4254  sampledimage().getType().cast<spirv::SampledImageType>();
4255  auto imageType = sampledImageType.getImageType().cast<spirv::ImageType>();
4256 
4257  if (resultType.getNumElements() != 4)
4258  return emitOpError("result type must be a vector of four components");
4259 
4260  Type elementType = resultType.getElementType();
4261  Type sampledElementType = imageType.getElementType();
4262  if (!sampledElementType.isa<NoneType>() && elementType != sampledElementType)
4263  return emitOpError(
4264  "the component type of result must be the same as sampled type of the "
4265  "underlying image type");
4266 
4267  spirv::Dim imageDim = imageType.getDim();
4268  spirv::ImageSamplingInfo imageMS = imageType.getSamplingInfo();
4269 
4270  if (imageDim != spirv::Dim::Dim2D && imageDim != spirv::Dim::Cube &&
4271  imageDim != spirv::Dim::Rect)
4272  return emitOpError(
4273  "the Dim operand of the underlying image type must be 2D, Cube, or "
4274  "Rect");
4275 
4276  if (imageMS != spirv::ImageSamplingInfo::SingleSampled)
4277  return emitOpError("the MS operand of the underlying image type must be 0");
4278 
4279  spirv::ImageOperandsAttr attr = imageoperandsAttr();
4280  auto operandArguments = operand_arguments();
4281 
4282  return verifyImageOperands(*this, attr, operandArguments);
4283 }
4284 
4285 //===----------------------------------------------------------------------===//
4286 // spv.ShiftLeftLogicalOp
4287 //===----------------------------------------------------------------------===//
4288 
4290  return verifyShiftOp(*this);
4291 }
4292 
4293 //===----------------------------------------------------------------------===//
4294 // spv.ShiftRightArithmeticOp
4295 //===----------------------------------------------------------------------===//
4296 
4298  return verifyShiftOp(*this);
4299 }
4300 
4301 //===----------------------------------------------------------------------===//
4302 // spv.ShiftRightLogicalOp
4303 //===----------------------------------------------------------------------===//
4304 
4306  return verifyShiftOp(*this);
4307 }
4308 
4309 //===----------------------------------------------------------------------===//
4310 // spv.ImageQuerySize
4311 //===----------------------------------------------------------------------===//
4312 
4314  spirv::ImageType imageType = image().getType().cast<spirv::ImageType>();
4315  Type resultType = result().getType();
4316 
4317  spirv::Dim dim = imageType.getDim();
4318  spirv::ImageSamplingInfo samplingInfo = imageType.getSamplingInfo();
4319  spirv::ImageSamplerUseInfo samplerInfo = imageType.getSamplerUseInfo();
4320  switch (dim) {
4321  case spirv::Dim::Dim1D:
4322  case spirv::Dim::Dim2D:
4323  case spirv::Dim::Dim3D:
4324  case spirv::Dim::Cube:
4325  if (samplingInfo != spirv::ImageSamplingInfo::MultiSampled &&
4326  samplerInfo != spirv::ImageSamplerUseInfo::SamplerUnknown &&
4327  samplerInfo != spirv::ImageSamplerUseInfo::NoSampler)
4328  return emitError(
4329  "if Dim is 1D, 2D, 3D, or Cube, "
4330  "it must also have either an MS of 1 or a Sampled of 0 or 2");
4331  break;
4332  case spirv::Dim::Buffer:
4333  case spirv::Dim::Rect:
4334  break;
4335  default:
4336  return emitError("the Dim operand of the image type must "
4337  "be 1D, 2D, 3D, Buffer, Cube, or Rect");
4338  }
4339 
4340  unsigned componentNumber = 0;
4341  switch (dim) {
4342  case spirv::Dim::Dim1D:
4343  case spirv::Dim::Buffer:
4344  componentNumber = 1;
4345  break;
4346  case spirv::Dim::Dim2D:
4347  case spirv::Dim::Cube:
4348  case spirv::Dim::Rect:
4349  componentNumber = 2;
4350  break;
4351  case spirv::Dim::Dim3D:
4352  componentNumber = 3;
4353  break;
4354  default:
4355  break;
4356  }
4357 
4358  if (imageType.getArrayedInfo() == spirv::ImageArrayedInfo::Arrayed)
4359  componentNumber += 1;
4360 
4361  unsigned resultComponentNumber = 1;
4362  if (auto resultVectorType = resultType.dyn_cast<VectorType>())
4363  resultComponentNumber = resultVectorType.getNumElements();
4364 
4365  if (componentNumber != resultComponentNumber)
4366  return emitError("expected the result to have ")
4367  << componentNumber << " component(s), but found "
4368  << resultComponentNumber << " component(s)";
4369 
4370  return success();
4371 }
4372 
4373 static ParseResult parsePtrAccessChainOpImpl(StringRef opName,
4374  OpAsmParser &parser,
4375  OperationState &state) {
4378  Type type;
4379  auto loc = parser.getCurrentLocation();
4380  SmallVector<Type, 4> indicesTypes;
4381 
4382  if (parser.parseOperand(ptrInfo) ||
4383  parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
4384  parser.parseColonType(type) ||
4385  parser.resolveOperand(ptrInfo, type, state.operands))
4386  return failure();
4387 
4388  // Check that the provided indices list is not empty before parsing their
4389  // type list.
4390  if (indicesInfo.empty())
4391  return emitError(state.location) << opName << " expected element";
4392 
4393  if (parser.parseComma() || parser.parseTypeList(indicesTypes))
4394  return failure();
4395 
4396  // Check that the indices types list is not empty and that it has a one-to-one
4397  // mapping to the provided indices.
4398  if (indicesTypes.size() != indicesInfo.size())
4399  return emitError(state.location)
4400  << opName
4401  << " indices types' count must be equal to indices info count";
4402 
4403  if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands))
4404  return failure();
4405 
4406  auto resultType = getElementPtrType(
4407  type, llvm::makeArrayRef(state.operands).drop_front(2), state.location);
4408  if (!resultType)
4409  return failure();
4410 
4411  state.addTypes(resultType);
4412  return success();
4413 }
4414 
4415 template <typename Op>
4416 static auto concatElemAndIndices(Op op) {
4417  SmallVector<Value> ret(op.indices().size() + 1);
4418  ret[0] = op.element();
4419  llvm::copy(op.indices(), ret.begin() + 1);
4420  return ret;
4421 }
4422 
4423 //===----------------------------------------------------------------------===//
4424 // spv.InBoundsPtrAccessChainOp
4425 //===----------------------------------------------------------------------===//
4426 
4427 void spirv::InBoundsPtrAccessChainOp::build(OpBuilder &builder,
4428  OperationState &state,
4429  Value basePtr, Value element,
4430  ValueRange indices) {
4431  auto type = getElementPtrType(basePtr.getType(), indices, state.location);
4432  assert(type && "Unable to deduce return type based on basePtr and indices");
4433  build(builder, state, type, basePtr, element, indices);
4434 }
4435 
4436 ParseResult spirv::InBoundsPtrAccessChainOp::parse(OpAsmParser &parser,
4437  OperationState &state) {
4439  spirv::InBoundsPtrAccessChainOp::getOperationName(), parser, state);
4440 }
4441 
4443  printAccessChain(*this, concatElemAndIndices(*this), printer);
4444 }
4445 
4447  return verifyAccessChain(*this, indices());
4448 }
4449 
4450 //===----------------------------------------------------------------------===//
4451 // spv.PtrAccessChainOp
4452 //===----------------------------------------------------------------------===//
4453 
4454 void spirv::PtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
4455  Value basePtr, Value element,
4456  ValueRange indices) {
4457  auto type = getElementPtrType(basePtr.getType(), indices, state.location);
4458  assert(type && "Unable to deduce return type based on basePtr and indices");
4459  build(builder, state, type, basePtr, element, indices);
4460 }
4461 
4462 ParseResult spirv::PtrAccessChainOp::parse(OpAsmParser &parser,
4463  OperationState &state) {
4464  return parsePtrAccessChainOpImpl(spirv::PtrAccessChainOp::getOperationName(),
4465  parser, state);
4466 }
4467 
4469  printAccessChain(*this, concatElemAndIndices(*this), printer);
4470 }
4471 
4473  return verifyAccessChain(*this, indices());
4474 }
4475 
4476 //===----------------------------------------------------------------------===//
4477 // spv.VectorTimesScalarOp
4478 //===----------------------------------------------------------------------===//
4479 
4481  if (vector().getType() != getType())
4482  return emitOpError("vector operand and result type mismatch");
4483  auto scalarType = getType().cast<VectorType>().getElementType();
4484  if (scalar().getType() != scalarType)
4485  return emitOpError("scalar operand and result element type match");
4486  return success();
4487 }
4488 
4489 // TableGen'erated operation interfaces for querying versions, extensions, and
4490 // capabilities.
4491 #include "mlir/Dialect/SPIRV/IR/SPIRVAvailability.cpp.inc"
4492 
4493 // TableGen'erated operation definitions.
4494 #define GET_OP_CLASSES
4495 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"
4496 
4497 namespace mlir {
4498 namespace spirv {
4499 // TableGen'erated operation availability interface implementations.
4500 #include "mlir/Dialect/SPIRV/IR/SPIRVOpAvailabilityImpl.inc"
4501 } // namespace spirv
4502 } // namespace mlir
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
void printFunctionSignature(OpAsmPrinter &p, Operation *op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
static bool isMergeBlock(Block &block)
Returns true if the given block only contains one spv.mlir.merge op.
Definition: SPIRVOps.cpp:755
TODO: Remove this file when SCCP and integer range analysis have been ported to the new framework...
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
static ParseResult parseImageOperands(OpAsmParser &parser, spirv::ImageOperandsAttr &attr)
Definition: SPIRVOps.cpp:368
static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:282
static void buildLogicalUnaryOp(OpBuilder &builder, OperationState &state, Value value)
Definition: SPIRVOps.cpp:959
Block * getSuccessor(unsigned i)
Definition: Block.cpp:240
This class contains a list of basic blocks and a link to the parent operation it is attached to...
Definition: Region.h:26
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:114
iterator begin()
Definition: Block.h:134
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:55
virtual ParseResult parseLParen()=0
Parse a ( token.
MLIRContext * getContext() const
Definition: Builders.h:54
U cast() const
Definition: Attributes.h:130
static ParseResult parsePtrAccessChainOpImpl(StringRef opName, OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:4373
static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp)
Definition: SPIRVOps.cpp:510
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
static LogicalResult verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics)
Definition: SPIRVOps.cpp:550
Operation & back()
Definition: Block.h:143
static std::string bindingName()
Returns the string name of the Binding decoration.
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:302
static constexpr const char kMemoryAccessAttrName[]
Definition: SPIRVOps.cpp:38
static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op, Value ptr, Value val)
Definition: SPIRVOps.cpp:589
unsigned getNumElements() const
Return the number of elements of the type.
Definition: SPIRVTypes.cpp:122
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
Type getPointeeType() const
Definition: SPIRVTypes.cpp:395
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:311
Block represents an ordered list of Operations.
Definition: Block.h:29
A symbol reference with a reference path containing a single element.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:363
Value getOperand(unsigned idx)
Definition: Operation.h:274
static constexpr const char kSourceMemoryAccessAttrName[]
Definition: SPIRVOps.cpp:39
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:382
OpListType & getOperations()
Definition: Block.h:128
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static LogicalResult verifyAtomicUpdateOp(Operation *op)
Definition: SPIRVOps.cpp:823
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments, to the list of operation attributes in result.
bool isa() const
Definition: Attributes.h:109
void push_back(Block *block)
Definition: Region.h:61
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
unsigned getNumElements() const
Definition: SPIRVTypes.cpp:60
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
unsigned getNumOperands()
Definition: Operation.h:270
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
static constexpr const char kDefaultValueAttrName[]
Definition: SPIRVOps.cpp:47
static ParseResult parseVariableDecorations(OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:601
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
Definition: SPIRVOps.cpp:688
operand_type_range getOperandTypes()
Definition: Operation.h:321
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:87
NoneType getNoneType()
Definition: Builders.cpp:75
static OpBuilder atBlockEnd(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still insid...
Definition: Builders.h:222
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
Definition: LogicalResult.h:68
static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr, Value val)
Definition: SPIRVOps.cpp:574
static void printOneResultOp(Operation *op, OpAsmPrinter &p)
Definition: SPIRVOps.cpp:97
StringRef getTypeAttrName()
Return the name of the attribute used for function types.
This is the representation of an operand reference.
Operation & front()
Definition: Block.h:144
static void printVariableDecorations(Operation *op, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs)
Definition: SPIRVOps.cpp:638
The OpAsmParser has methods for interacting with the asm parser: parsing things from it...
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
static std::string descriptorSetName()
Returns the string name of the DescriptorSet decoration.
static ParseResult parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next string attribute in parser as an enumerant of the given EnumClass.
Definition: SPIRVOps.cpp:172
static ParseResult parseAtomicCompareExchangeImpl(OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:1143
static bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock)
Returns true if the given srcBlock contains only one spv.Branch to the given dstBlock.
Definition: SPIRVOps.cpp:2979
static bool isNestedInFunctionOpInterface(Operation *op)
Returns true if the given op is a function-like op or nested in a function-like op without a module-l...
Definition: SPIRVOps.cpp:118
static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer, Type coopMatrix)
Definition: SPIRVOps.cpp:3749
static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer)
Definition: SPIRVOps.cpp:1078
static StringRef stringifyTypeName()
static constexpr const char kControl[]
Definition: SPIRVOps.cpp:46
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseSuccessorAndUseList(Block *&dest, SmallVectorImpl< Value > &operands)=0
Parse a single operation successor and its operand list.
static constexpr const bool value
ParseResult parseSymbolName(StringAttr &result, StringRef attrName, NamedAttrList &attrs)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute named 'attrName'...
static void printAtomicUpdateOp(Operation *op, OpAsmPrinter &printer)
Definition: SPIRVOps.cpp:795
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
SmallVector< Value, 4 > operands
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
StringRef stringifyTypeName< FloatType >()
Definition: SPIRVOps.cpp:817
static constexpr const char kInterfaceAttrName[]
Definition: SPIRVOps.cpp:54
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, unsigned numInputs, unsigned numResults, ArrayRef< StringRef > elided={})
Prints the list of function prefixed with the "attributes" keyword.
static constexpr const char kSpecIdAttrName[]
Definition: SPIRVOps.cpp:57
static constexpr const char kTypeAttrName[]
Definition: SPIRVOps.cpp:58
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:193
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:148
virtual ParseResult parseColon()=0
Parse a : token.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:380
ImageArrayedInfo getArrayedInfo() const
Definition: SPIRVTypes.cpp:337
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
bool empty()
Definition: Region.h:60
virtual OptionalParseResult parseOptionalRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
static unsigned getBitWidth(Type type)
Definition: SPIRVOps.cpp:666
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:639
void addOperands(ValueRange newOperands)
static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, Type opType)
Definition: SPIRVOps.cpp:1822
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:170
static void printAtomicCompareExchangeImpl(T atomOp, OpAsmPrinter &printer)
Definition: SPIRVOps.cpp:1136
Type getFunctionType(Builder &builder, ArrayRef< OpAsmParser::Argument > argAttrs, ArrayRef< Type > resultTypes)
Get a function type corresponding to an array of arguments (which have types) and a set of result typ...
U dyn_cast() const
Definition: Types.h:256
iterator end()
Definition: Block.h:135
static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp)
Definition: SPIRVOps.cpp:466
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
Attributes are known-constant values of operations.
Definition: Attributes.h:24
static constexpr const char kIndicesAttrName[]
Definition: SPIRVOps.cpp:52
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static constexpr const char kClusterSize[]
Definition: SPIRVOps.cpp:45
ImageSamplerUseInfo getSamplerUseInfo() const
Definition: SPIRVTypes.cpp:345
Type getElementType() const
Definition: SPIRVTypes.cpp:62
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:526
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:172
static void printSourceMemoryAccessAttribute(MemoryOpTy memoryOp, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs, Optional< spirv::MemoryAccess > memoryAccessAtrrValue=None, Optional< uint32_t > alignmentAttrValue=None)
Definition: SPIRVOps.cpp:340
static constexpr const char kAlignmentAttrName[]
Definition: SPIRVOps.cpp:41
Type getElementType() const
Returns the elements' type (i.e., single element type).
static constexpr const char kCallee[]
Definition: SPIRVOps.cpp:44
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:58
virtual void printGenericOp(Operation *op, bool printOpName=true)=0
Print the entire operation with the default generic assembly form.
virtual ParseResult parseRParen()=0
Parse a ) token.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:41
Block & back()
Definition: Region.h:64
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:338
static void buildLogicalBinaryOp(OpBuilder &builder, OperationState &state, Value lhs, Value rhs)
Definition: SPIRVOps.cpp:947
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:331
A callable is either a symbol, or an SSA value, that is referenced by a call-like operation...
static constexpr const char kEqualSemanticsAttrName[]
Definition: SPIRVOps.cpp:49
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:215
void addTypes(ArrayRef< Type > newTypes)
IntegerType getI1Type()
Definition: Builders.cpp:50
ParseResult parseKeyword(StringRef keyword, const Twine &msg="")
Parse a given keyword.
A trait to mark ops that can be enclosed/wrapped in a SpecConstantOperation op.
Definition: SPIRVOpTraits.h:32
static constexpr const char kValuesAttrName[]
Definition: SPIRVOps.cpp:61
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
static WalkResult advance()
Definition: Visitors.h:51
static void printGroupNonUniformArithmeticOp(Operation *groupOp, OpAsmPrinter &printer)
Definition: SPIRVOps.cpp:882
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:161
This represents an operation in an abstracted form, suitable for use with the builder APIs...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true...
This class models how operands are forwarded to block arguments in control flow.
static constexpr const char kMemoryScopeAttrName[]
Definition: SPIRVOps.cpp:55
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:34
An attribute that specifies the SPIR-V (version, capabilities, extensions) triple.
static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp)
Definition: SPIRVOps.cpp:899
ParseResult resolveOperands(ArrayRef< UnresolvedOperand > operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:77
static bool isDirectInModuleLikeOp(Operation *op)
Returns true if the given op is a module-like op that maintains a symbol table.
Definition: SPIRVOps.cpp:130
static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc)
Definition: SPIRVOps.cpp:973
static ArrayAttr getStrArrayAttrForEnumList(Builder &builder, ArrayRef< Ty > enumValues, function_ref< StringRef(Ty)> stringifyFn)
Definition: SPIRVOps.cpp:155
virtual ParseResult parseRSquare()=0
Parse a ] token.
static ParseResult parseAtomicUpdateOp(OpAsmParser &parser, OperationState &state, bool hasValue)
Definition: SPIRVOps.cpp:766
static LogicalResult extractValueFromConstOp(Operation *op, int32_t &value)
Definition: SPIRVOps.cpp:134
static constexpr const char kFnNameAttrName[]
Definition: SPIRVOps.cpp:50
static constexpr const char kSemanticsAttrName[]
Definition: SPIRVOps.cpp:56
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
bool empty()
Definition: Block.h:139
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:391
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of...
static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser, OperationState &state)
Parses optional memory access attributes attached to a memory access operand/pointer.
Definition: SPIRVOps.cpp:251
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:19
void addSuccessors(Block *successor)
static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, OperationState &result)
Definition: SPIRVOps.cpp:68
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
static ParseResult parseControlAttribute(OpAsmParser &parser, OperationState &state, StringRef attrName=spirv::attributeName< EnumClass >())
Parses Function, Selection and Loop control attributes.
Definition: SPIRVOps.cpp:230
virtual Operation * parseGenericOperation(Block *insertBlock, Block::iterator insertPt)=0
Parse an operation in its generic form.
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:37
static constexpr const char kInitializerAttrName[]
Definition: SPIRVOps.cpp:53
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:230
unsigned getNumElements() const
Definition: SPIRVTypes.cpp:993
NamedAttrList attributes
virtual void printSuccessorAndUseList(Block *successor, ValueRange succOperands)=0
Print the successor and its operands.
Type getType() const
Return the type of this attribute.
Definition: Attributes.h:66
static constexpr const char kSourceAlignmentAttrName[]
Definition: SPIRVOps.cpp:42
ParseResult parseFunctionSignature(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values...
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:286
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Region * addRegion()
Create a region that should be attached to the operation.
Type getElementType(unsigned) const
Definition: SPIRVTypes.cpp:995
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
Dim
Dimension level type for a tensor (undef means index does not appear).
Definition: Merger.h:24
static LogicalResult verifyImageOperands(Op imageOp, spirv::ImageOperandsAttr attr, Operation::operand_range operands)
Definition: SPIRVOps.cpp:392
This class is a general helper class for creating context-global objects like types, attributes, and affine expressions.
Definition: Builders.h:49
static LogicalResult verifyCastOp(Operation *op, bool requireSameBitWidth=true, bool skipBitWidthCheck=false)
Definition: SPIRVOps.cpp:419
static constexpr const char kExecutionScopeAttrName[]
Definition: SPIRVOps.cpp:48
Type getType() const
Return the type of this value.
Definition: Value.h:118
TargetEnvAttr getDefaultTargetEnv(MLIRContext *context)
Returns the default target environment: SPIR-V 1.0 with Shader capability and no extra extensions...
Location getLoc()
The source location the operation was defined or derived from.
Definition: OpDefinition.h:108
unsigned getNumSuccessors()
Definition: Block.cpp:236
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser, OperationState &state)
Definition: SPIRVOps.cpp:846
static constexpr const char kCompositeSpecConstituentsName[]
Definition: SPIRVOps.cpp:62
unsigned getRows() const
return the number of rows of the matrix.
Definition: SPIRVTypes.cpp:234
U dyn_cast() const
Definition: Attributes.h:124
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation...
Definition: Visitors.cpp:24
static int64_t getNumElements(ShapedType type)
Definition: TensorOps.cpp:753
type_range getTypes() const
static VectorType vectorType(CodeGen &codegen, Type etp)
Constructs vector type.
static LogicalResult verifyShiftOp(Operation *op)
Definition: SPIRVOps.cpp:937
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:87
virtual ParseResult parseType(Type &result)=0
Parse a type.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
DenseIntElementsAttr getI32VectorAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:109
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
SPIR-V struct type.
Definition: SPIRVTypes.h:278
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:91
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers...
Definition: Operation.cpp:645
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:23
This class implements the operand iterators for the Operation class.
This provides public APIs that all operations should have.
static Type getUnaryOpResultType(Type operandType)
Result of a logical op must be a scalar or vector of boolean type.
Definition: SPIRVOps.cpp:929
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:374
static LogicalResult verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op)
Definition: SPIRVOps.cpp:3818
static constexpr const char kBranchWeightAttrName[]
Definition: SPIRVOps.cpp:43
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:372
static auto concatElemAndIndices(Op op)
Definition: SPIRVOps.cpp:4416
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:328
static void printMemoryAccessAttribute(MemoryOpTy memoryOp, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs, Optional< spirv::MemoryAccess > memoryAccessAtrrValue=None, Optional< uint32_t > alignmentAttrValue=None)
Definition: SPIRVOps.cpp:310
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into &#39;result&#39; if it is present.
bool hasValue() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:47
static constexpr const char kGroupOperationAttrName[]
Definition: SPIRVOps.cpp:51
virtual ParseResult parseEqual()=0
Parse a = token.
Scope getScope() const
Return the scope of the cooperative matrix.
Definition: SPIRVTypes.cpp:232
unsigned getNumColumns() const
Returns the number of columns.
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:67
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=llvm::None, ArrayRef< Location > locs=llvm::None)
Add new block with 'argTypes' arguments and set the insertion point to the end of it...
Definition: Builders.cpp:353
unsigned getNumRows() const
Returns the number of rows.
StringRef stringifyTypeName< IntegerType >()
Definition: SPIRVOps.cpp:812
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers...
Definition: Operation.cpp:518
static ParseResult parseEnumKeywordAttr(EnumClass &value, ParserType &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next keyword in parser as an enumerant of the given EnumClass.
Definition: ParserUtils.h:27
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:57
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
bool isa() const
Definition: Types.h:246
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:231
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:66
This class represents success/failure for parsing-like operations that find it important to chain tog...
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
static void printImageOperands(OpAsmPrinter &printer, Operation *imageOp, spirv::ImageOperandsAttr attr)
Definition: SPIRVOps.cpp:383
static constexpr const char kValueAttrName[]
Definition: SPIRVOps.cpp:60
This class helps build Operations.
Definition: Builders.h:184
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:205
This class provides an abstraction over the different types of ranges over Values.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter...
ImageSamplingInfo getSamplingInfo() const
Definition: SPIRVTypes.cpp:341
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:201
ParseResult addTypesToList(ArrayRef< Type > types, SmallVectorImpl< Type > &result)
Add the specified types to the end of the specified type list and return success. ...
static constexpr const char kUnequalSemanticsAttrName[]
Definition: SPIRVOps.cpp:59
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
virtual ParseResult parseOptionalSymbolName(StringAttr &result, StringRef attrName, NamedAttrList &attrs)=0
Parse an optional @-identifier and store it (without the '@' symbol) in a string attribute named 'attr...
unsigned getColumns() const
return the number of columns of the matrix.
Definition: SPIRVTypes.cpp:236
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Square brackets surrounding zero or more operands.
static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices)
Definition: SPIRVOps.cpp:1088
U cast() const
Definition: Types.h:262
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
static LogicalResult verifyAtomicCompareExchangeImpl(T atomOp)
Definition: SPIRVOps.cpp:1175
SmallVector< Type, 4 > types
Types of the results of this operation.
ArrayAttr getStrArrayAttr(ArrayRef< StringRef > values)
Definition: Builders.cpp:246