MLIR  20.0.0git
SPIRVOps.cpp
Go to the documentation of this file.
1 //===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the operations in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
15 #include "SPIRVOpUtils.h"
16 #include "SPIRVParsingUtils.h"
17 
24 #include "mlir/IR/Builders.h"
25 #include "mlir/IR/BuiltinTypes.h"
26 #include "mlir/IR/Matchers.h"
27 #include "mlir/IR/OpDefinition.h"
29 #include "mlir/IR/Operation.h"
30 #include "mlir/IR/TypeUtilities.h"
32 #include "llvm/ADT/APFloat.h"
33 #include "llvm/ADT/APInt.h"
34 #include "llvm/ADT/ArrayRef.h"
35 #include "llvm/ADT/STLExtras.h"
36 #include "llvm/ADT/StringExtras.h"
37 #include "llvm/ADT/TypeSwitch.h"
38 #include <cassert>
39 #include <numeric>
40 #include <optional>
41 #include <type_traits>
42 
43 using namespace mlir;
44 using namespace mlir::spirv::AttrNames;
45 
46 //===----------------------------------------------------------------------===//
47 // Common utility functions
48 //===----------------------------------------------------------------------===//
49 
50 LogicalResult spirv::extractValueFromConstOp(Operation *op, int32_t &value) {
51  auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
52  if (!constOp) {
53  return failure();
54  }
55  auto valueAttr = constOp.getValue();
56  auto integerValueAttr = llvm::dyn_cast<IntegerAttr>(valueAttr);
57  if (!integerValueAttr) {
58  return failure();
59  }
60 
61  if (integerValueAttr.getType().isSignlessInteger())
62  value = integerValueAttr.getInt();
63  else
64  value = integerValueAttr.getSInt();
65 
66  return success();
67 }
68 
69 LogicalResult
71  spirv::MemorySemantics memorySemantics) {
72  // According to the SPIR-V specification:
73  // "Despite being a mask and allowing multiple bits to be combined, it is
74  // invalid for more than one of these four bits to be set: Acquire, Release,
75  // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and
76  // Release semantics is done by setting the AcquireRelease bit, not by setting
77  // two bits."
78  auto atMostOneInSet = spirv::MemorySemantics::Acquire |
79  spirv::MemorySemantics::Release |
80  spirv::MemorySemantics::AcquireRelease |
81  spirv::MemorySemantics::SequentiallyConsistent;
82 
83  auto bitCount =
84  llvm::popcount(static_cast<uint32_t>(memorySemantics & atMostOneInSet));
85  if (bitCount > 1) {
86  return op->emitError(
87  "expected at most one of these four memory constraints "
88  "to be set: `Acquire`, `Release`,"
89  "`AcquireRelease` or `SequentiallyConsistent`");
90  }
91  return success();
92 }
93 
95  SmallVectorImpl<StringRef> &elidedAttrs) {
96  // Print optional descriptor binding
97  auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
98  stringifyDecoration(spirv::Decoration::DescriptorSet));
99  auto bindingName = llvm::convertToSnakeFromCamelCase(
100  stringifyDecoration(spirv::Decoration::Binding));
101  auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
102  auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
103  if (descriptorSet && binding) {
104  elidedAttrs.push_back(descriptorSetName);
105  elidedAttrs.push_back(bindingName);
106  printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
107  << ")";
108  }
109 
110  // Print BuiltIn attribute if present
111  auto builtInName = llvm::convertToSnakeFromCamelCase(
112  stringifyDecoration(spirv::Decoration::BuiltIn));
113  if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
114  printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
115  elidedAttrs.push_back(builtInName);
116  }
117 
118  printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
119 }
120 
121 static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
122  OperationState &result) {
124  Type type;
125  // If the operand list is in-between parentheses, then we have a generic form.
126  // (see the fallback in `printOneResultOp`).
127  SMLoc loc = parser.getCurrentLocation();
128  if (!parser.parseOptionalLParen()) {
129  if (parser.parseOperandList(ops) || parser.parseRParen() ||
130  parser.parseOptionalAttrDict(result.attributes) ||
131  parser.parseColon() || parser.parseType(type))
132  return failure();
133  auto fnType = llvm::dyn_cast<FunctionType>(type);
134  if (!fnType) {
135  parser.emitError(loc, "expected function type");
136  return failure();
137  }
138  if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands))
139  return failure();
140  result.addTypes(fnType.getResults());
141  return success();
142  }
143  return failure(parser.parseOperandList(ops) ||
144  parser.parseOptionalAttrDict(result.attributes) ||
145  parser.parseColonType(type) ||
146  parser.resolveOperands(ops, type, result.operands) ||
147  parser.addTypeToList(type, result.types));
148 }
149 
151  assert(op->getNumResults() == 1 && "op should have one result");
152 
153  // If not all the operand and result types are the same, just use the
154  // generic assembly form to avoid omitting information in printing.
155  auto resultType = op->getResult(0).getType();
156  if (llvm::any_of(op->getOperandTypes(),
157  [&](Type type) { return type != resultType; })) {
158  p.printGenericOp(op, /*printOpName=*/false);
159  return;
160  }
161 
162  p << ' ';
163  p.printOperands(op->getOperands());
165  // Now we can output only one type for all operands and the result.
166  p << " : " << resultType;
167 }
168 
169 template <typename Op>
170 static LogicalResult verifyImageOperands(Op imageOp,
171  spirv::ImageOperandsAttr attr,
172  Operation::operand_range operands) {
173  if (!attr) {
174  if (operands.empty())
175  return success();
176 
177  return imageOp.emitError("the Image Operands should encode what operands "
178  "follow, as per Image Operands");
179  }
180 
181  // TODO: Add the validation rules for the following Image Operands.
182  spirv::ImageOperands noSupportOperands =
183  spirv::ImageOperands::Bias | spirv::ImageOperands::Lod |
184  spirv::ImageOperands::Grad | spirv::ImageOperands::ConstOffset |
185  spirv::ImageOperands::Offset | spirv::ImageOperands::ConstOffsets |
186  spirv::ImageOperands::Sample | spirv::ImageOperands::MinLod |
187  spirv::ImageOperands::MakeTexelAvailable |
188  spirv::ImageOperands::MakeTexelVisible |
189  spirv::ImageOperands::SignExtend | spirv::ImageOperands::ZeroExtend;
190 
191  if (spirv::bitEnumContainsAll(attr.getValue(), noSupportOperands))
192  llvm_unreachable("unimplemented operands of Image Operands");
193 
194  return success();
195 }
196 
197 template <typename BlockReadWriteOpTy>
198 static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
199  Value ptr, Value val) {
200  auto valType = val.getType();
201  if (auto valVecTy = llvm::dyn_cast<VectorType>(valType))
202  valType = valVecTy.getElementType();
203 
204  if (valType !=
205  llvm::cast<spirv::PointerType>(ptr.getType()).getPointeeType()) {
206  return op.emitOpError("mismatch in result type and pointer type");
207  }
208  return success();
209 }
210 
211 /// Walks the given type hierarchy with the given indices, potentially down
212 /// to component granularity, to select an element type. Returns null type and
213 /// emits errors with the given loc on failure.
214 static Type
216  function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
217  if (indices.empty()) {
218  emitErrorFn("expected at least one index for spirv.CompositeExtract");
219  return nullptr;
220  }
221 
222  for (auto index : indices) {
223  if (auto cType = llvm::dyn_cast<spirv::CompositeType>(type)) {
224  if (cType.hasCompileTimeKnownNumElements() &&
225  (index < 0 ||
226  static_cast<uint64_t>(index) >= cType.getNumElements())) {
227  emitErrorFn("index ") << index << " out of bounds for " << type;
228  return nullptr;
229  }
230  type = cType.getElementType(index);
231  } else {
232  emitErrorFn("cannot extract from non-composite type ")
233  << type << " with index " << index;
234  return nullptr;
235  }
236  }
237  return type;
238 }
239 
240 static Type
242  function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
243  auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(indices);
244  if (!indicesArrayAttr) {
245  emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
246  return nullptr;
247  }
248  if (indicesArrayAttr.empty()) {
249  emitErrorFn("expected at least one index for spirv.CompositeExtract");
250  return nullptr;
251  }
252 
253  SmallVector<int32_t, 2> indexVals;
254  for (auto indexAttr : indicesArrayAttr) {
255  auto indexIntAttr = llvm::dyn_cast<IntegerAttr>(indexAttr);
256  if (!indexIntAttr) {
257  emitErrorFn("expected an 32-bit integer for index, but found '")
258  << indexAttr << "'";
259  return nullptr;
260  }
261  indexVals.push_back(indexIntAttr.getInt());
262  }
263  return getElementType(type, indexVals, emitErrorFn);
264 }
265 
266 static Type getElementType(Type type, Attribute indices, Location loc) {
267  auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
268  return ::mlir::emitError(loc, err);
269  };
270  return getElementType(type, indices, errorFn);
271 }
272 
273 static Type getElementType(Type type, Attribute indices, OpAsmParser &parser,
274  SMLoc loc) {
275  auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
276  return parser.emitError(loc, err);
277  };
278  return getElementType(type, indices, errorFn);
279 }
280 
281 template <typename ExtendedBinaryOp>
282 static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op) {
283  auto resultType = llvm::cast<spirv::StructType>(op.getType());
284  if (resultType.getNumElements() != 2)
285  return op.emitOpError("expected result struct type containing two members");
286 
287  if (!llvm::all_equal({op.getOperand1().getType(), op.getOperand2().getType(),
288  resultType.getElementType(0),
289  resultType.getElementType(1)}))
290  return op.emitOpError(
291  "expected all operand types and struct member types are the same");
292 
293  return success();
294 }
295 
296 static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser,
297  OperationState &result) {
299  if (parser.parseOptionalAttrDict(result.attributes) ||
300  parser.parseOperandList(operands) || parser.parseColon())
301  return failure();
302 
303  Type resultType;
304  SMLoc loc = parser.getCurrentLocation();
305  if (parser.parseType(resultType))
306  return failure();
307 
308  auto structType = llvm::dyn_cast<spirv::StructType>(resultType);
309  if (!structType || structType.getNumElements() != 2)
310  return parser.emitError(loc, "expected spirv.struct type with two members");
311 
312  SmallVector<Type, 2> operandTypes(2, structType.getElementType(0));
313  if (parser.resolveOperands(operands, operandTypes, loc, result.operands))
314  return failure();
315 
316  result.addTypes(resultType);
317  return success();
318 }
319 
321  OpAsmPrinter &printer) {
322  printer << ' ';
323  printer.printOptionalAttrDict(op->getAttrs());
324  printer.printOperands(op->getOperands());
325  printer << " : " << op->getResultTypes().front();
326 }
327 
328 static LogicalResult verifyShiftOp(Operation *op) {
329  if (op->getOperand(0).getType() != op->getResult(0).getType()) {
330  return op->emitError("expected the same type for the first operand and "
331  "result, but provided ")
332  << op->getOperand(0).getType() << " and "
333  << op->getResult(0).getType();
334  }
335  return success();
336 }
337 
338 //===----------------------------------------------------------------------===//
339 // spirv.mlir.addressof
340 //===----------------------------------------------------------------------===//
341 
342 void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state,
343  spirv::GlobalVariableOp var) {
344  build(builder, state, var.getType(), SymbolRefAttr::get(var));
345 }
346 
347 LogicalResult spirv::AddressOfOp::verify() {
348  auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
349  SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(),
350  getVariableAttr()));
351  if (!varOp) {
352  return emitOpError("expected spirv.GlobalVariable symbol");
353  }
354  if (getPointer().getType() != varOp.getType()) {
355  return emitOpError(
356  "result type mismatch with the referenced global variable's type");
357  }
358  return success();
359 }
360 
361 //===----------------------------------------------------------------------===//
362 // spirv.CompositeConstruct
363 //===----------------------------------------------------------------------===//
364 
365 LogicalResult spirv::CompositeConstructOp::verify() {
366  operand_range constituents = this->getConstituents();
367 
368  // There are 4 cases with varying verification rules:
369  // 1. Cooperative Matrices (1 constituent)
370  // 2. Structs (1 constituent for each member)
371  // 3. Arrays (1 constituent for each array element)
372  // 4. Vectors (1 constituent (sub-)element for each vector element)
373 
374  auto coopElementType =
377  [](auto coopType) { return coopType.getElementType(); })
378  .Default([](Type) { return nullptr; });
379 
380  // Case 1. -- matrices.
381  if (coopElementType) {
382  if (constituents.size() != 1)
383  return emitOpError("has incorrect number of operands: expected ")
384  << "1, but provided " << constituents.size();
385  if (coopElementType != constituents.front().getType())
386  return emitOpError("operand type mismatch: expected operand type ")
387  << coopElementType << ", but provided "
388  << constituents.front().getType();
389  return success();
390  }
391 
392  // Case 2./3./4. -- number of constituents matches the number of elements.
393  auto cType = llvm::cast<spirv::CompositeType>(getType());
394  if (constituents.size() == cType.getNumElements()) {
395  for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
396  if (constituents[index].getType() != cType.getElementType(index)) {
397  return emitOpError("operand type mismatch: expected operand type ")
398  << cType.getElementType(index) << ", but provided "
399  << constituents[index].getType();
400  }
401  }
402  return success();
403  }
404 
405  // Case 4. -- check that all constituents add up tp the expected vector type.
406  auto resultType = llvm::dyn_cast<VectorType>(cType);
407  if (!resultType)
408  return emitOpError(
409  "expected to return a vector or cooperative matrix when the number of "
410  "constituents is less than what the result needs");
411 
412  SmallVector<unsigned> sizes;
413  for (Value component : constituents) {
414  if (!llvm::isa<VectorType>(component.getType()) &&
415  !component.getType().isIntOrFloat())
416  return emitOpError("operand type mismatch: expected operand to have "
417  "a scalar or vector type, but provided ")
418  << component.getType();
419 
420  Type elementType = component.getType();
421  if (auto vectorType = llvm::dyn_cast<VectorType>(component.getType())) {
422  sizes.push_back(vectorType.getNumElements());
423  elementType = vectorType.getElementType();
424  } else {
425  sizes.push_back(1);
426  }
427 
428  if (elementType != resultType.getElementType())
429  return emitOpError("operand element type mismatch: expected to be ")
430  << resultType.getElementType() << ", but provided " << elementType;
431  }
432  unsigned totalCount = std::accumulate(sizes.begin(), sizes.end(), 0);
433  if (totalCount != cType.getNumElements())
434  return emitOpError("has incorrect number of operands: expected ")
435  << cType.getNumElements() << ", but provided " << totalCount;
436  return success();
437 }
438 
439 //===----------------------------------------------------------------------===//
440 // spirv.CompositeExtractOp
441 //===----------------------------------------------------------------------===//
442 
443 void spirv::CompositeExtractOp::build(OpBuilder &builder, OperationState &state,
444  Value composite,
445  ArrayRef<int32_t> indices) {
446  auto indexAttr = builder.getI32ArrayAttr(indices);
447  auto elementType =
448  getElementType(composite.getType(), indexAttr, state.location);
449  if (!elementType) {
450  return;
451  }
452  build(builder, state, elementType, composite, indexAttr);
453 }
454 
456  OperationState &result) {
457  OpAsmParser::UnresolvedOperand compositeInfo;
458  Attribute indicesAttr;
459  StringRef indicesAttrName =
460  spirv::CompositeExtractOp::getIndicesAttrName(result.name);
461  Type compositeType;
462  SMLoc attrLocation;
463 
464  if (parser.parseOperand(compositeInfo) ||
465  parser.getCurrentLocation(&attrLocation) ||
466  parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
467  parser.parseColonType(compositeType) ||
468  parser.resolveOperand(compositeInfo, compositeType, result.operands)) {
469  return failure();
470  }
471 
472  Type resultType =
473  getElementType(compositeType, indicesAttr, parser, attrLocation);
474  if (!resultType) {
475  return failure();
476  }
477  result.addTypes(resultType);
478  return success();
479 }
480 
482  printer << ' ' << getComposite() << getIndices() << " : "
483  << getComposite().getType();
484 }
485 
486 LogicalResult spirv::CompositeExtractOp::verify() {
487  auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(getIndices());
488  auto resultType =
489  getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
490  if (!resultType)
491  return failure();
492 
493  if (resultType != getType()) {
494  return emitOpError("invalid result type: expected ")
495  << resultType << " but provided " << getType();
496  }
497 
498  return success();
499 }
500 
501 //===----------------------------------------------------------------------===//
502 // spirv.CompositeInsert
503 //===----------------------------------------------------------------------===//
504 
505 void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state,
506  Value object, Value composite,
507  ArrayRef<int32_t> indices) {
508  auto indexAttr = builder.getI32ArrayAttr(indices);
509  build(builder, state, composite.getType(), object, composite, indexAttr);
510 }
511 
512 ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
513  OperationState &result) {
515  Type objectType, compositeType;
516  Attribute indicesAttr;
517  StringRef indicesAttrName =
518  spirv::CompositeInsertOp::getIndicesAttrName(result.name);
519  auto loc = parser.getCurrentLocation();
520 
521  return failure(
522  parser.parseOperandList(operands, 2) ||
523  parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
524  parser.parseColonType(objectType) ||
525  parser.parseKeywordType("into", compositeType) ||
526  parser.resolveOperands(operands, {objectType, compositeType}, loc,
527  result.operands) ||
528  parser.addTypesToList(compositeType, result.types));
529 }
530 
531 LogicalResult spirv::CompositeInsertOp::verify() {
532  auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(getIndices());
533  auto objectType =
534  getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
535  if (!objectType)
536  return failure();
537 
538  if (objectType != getObject().getType()) {
539  return emitOpError("object operand type should be ")
540  << objectType << ", but found " << getObject().getType();
541  }
542 
543  if (getComposite().getType() != getType()) {
544  return emitOpError("result type should be the same as "
545  "the composite type, but found ")
546  << getComposite().getType() << " vs " << getType();
547  }
548 
549  return success();
550 }
551 
553  printer << " " << getObject() << ", " << getComposite() << getIndices()
554  << " : " << getObject().getType() << " into "
555  << getComposite().getType();
556 }
557 
558 //===----------------------------------------------------------------------===//
559 // spirv.Constant
560 //===----------------------------------------------------------------------===//
561 
562 ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
563  OperationState &result) {
564  Attribute value;
565  StringRef valueAttrName = spirv::ConstantOp::getValueAttrName(result.name);
566  if (parser.parseAttribute(value, valueAttrName, result.attributes))
567  return failure();
568 
569  Type type = NoneType::get(parser.getContext());
570  if (auto typedAttr = llvm::dyn_cast<TypedAttr>(value))
571  type = typedAttr.getType();
572  if (llvm::isa<NoneType, TensorType>(type)) {
573  if (parser.parseColonType(type))
574  return failure();
575  }
576 
577  return parser.addTypeToList(type, result.types);
578 }
579 
580 void spirv::ConstantOp::print(OpAsmPrinter &printer) {
581  printer << ' ' << getValue();
582  if (llvm::isa<spirv::ArrayType>(getType()))
583  printer << " : " << getType();
584 }
585 
586 static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
587  Type opType) {
588  if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
589  auto valueType = llvm::cast<TypedAttr>(value).getType();
590  if (valueType != opType)
591  return op.emitOpError("result type (")
592  << opType << ") does not match value type (" << valueType << ")";
593  return success();
594  }
595  if (llvm::isa<DenseIntOrFPElementsAttr, SparseElementsAttr>(value)) {
596  auto valueType = llvm::cast<TypedAttr>(value).getType();
597  if (valueType == opType)
598  return success();
599  auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
600  auto shapedType = llvm::dyn_cast<ShapedType>(valueType);
601  if (!arrayType)
602  return op.emitOpError("result or element type (")
603  << opType << ") does not match value type (" << valueType
604  << "), must be the same or spirv.array";
605 
606  int numElements = arrayType.getNumElements();
607  auto opElemType = arrayType.getElementType();
608  while (auto t = llvm::dyn_cast<spirv::ArrayType>(opElemType)) {
609  numElements *= t.getNumElements();
610  opElemType = t.getElementType();
611  }
612  if (!opElemType.isIntOrFloat())
613  return op.emitOpError("only support nested array result type");
614 
615  auto valueElemType = shapedType.getElementType();
616  if (valueElemType != opElemType) {
617  return op.emitOpError("result element type (")
618  << opElemType << ") does not match value element type ("
619  << valueElemType << ")";
620  }
621 
622  if (numElements != shapedType.getNumElements()) {
623  return op.emitOpError("result number of elements (")
624  << numElements << ") does not match value number of elements ("
625  << shapedType.getNumElements() << ")";
626  }
627  return success();
628  }
629  if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(value)) {
630  auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
631  if (!arrayType)
632  return op.emitOpError(
633  "must have spirv.array result type for array value");
634  Type elemType = arrayType.getElementType();
635  for (Attribute element : arrayAttr.getValue()) {
636  // Verify array elements recursively.
637  if (failed(verifyConstantType(op, element, elemType)))
638  return failure();
639  }
640  return success();
641  }
642  return op.emitOpError("cannot have attribute: ") << value;
643 }
644 
645 LogicalResult spirv::ConstantOp::verify() {
646  // ODS already generates checks to make sure the result type is valid. We just
647  // need to additionally check that the value's attribute type is consistent
648  // with the result type.
649  return verifyConstantType(*this, getValueAttr(), getType());
650 }
651 
652 bool spirv::ConstantOp::isBuildableWith(Type type) {
653  // Must be valid SPIR-V type first.
654  if (!llvm::isa<spirv::SPIRVType>(type))
655  return false;
656 
657  if (isa<SPIRVDialect>(type.getDialect())) {
658  // TODO: support constant struct
659  return llvm::isa<spirv::ArrayType>(type);
660  }
661 
662  return true;
663 }
664 
665 spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
666  OpBuilder &builder) {
667  if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
668  unsigned width = intType.getWidth();
669  if (width == 1)
670  return builder.create<spirv::ConstantOp>(loc, type,
671  builder.getBoolAttr(false));
672  return builder.create<spirv::ConstantOp>(
673  loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
674  }
675  if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
676  return builder.create<spirv::ConstantOp>(
677  loc, type, builder.getFloatAttr(floatType, 0.0));
678  }
679  if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
680  Type elemType = vectorType.getElementType();
681  if (llvm::isa<IntegerType>(elemType)) {
682  return builder.create<spirv::ConstantOp>(
683  loc, type,
684  DenseElementsAttr::get(vectorType,
685  IntegerAttr::get(elemType, 0).getValue()));
686  }
687  if (llvm::isa<FloatType>(elemType)) {
688  return builder.create<spirv::ConstantOp>(
689  loc, type,
690  DenseFPElementsAttr::get(vectorType,
691  FloatAttr::get(elemType, 0.0).getValue()));
692  }
693  }
694 
695  llvm_unreachable("unimplemented types for ConstantOp::getZero()");
696 }
697 
698 spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
699  OpBuilder &builder) {
700  if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
701  unsigned width = intType.getWidth();
702  if (width == 1)
703  return builder.create<spirv::ConstantOp>(loc, type,
704  builder.getBoolAttr(true));
705  return builder.create<spirv::ConstantOp>(
706  loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
707  }
708  if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
709  return builder.create<spirv::ConstantOp>(
710  loc, type, builder.getFloatAttr(floatType, 1.0));
711  }
712  if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
713  Type elemType = vectorType.getElementType();
714  if (llvm::isa<IntegerType>(elemType)) {
715  return builder.create<spirv::ConstantOp>(
716  loc, type,
717  DenseElementsAttr::get(vectorType,
718  IntegerAttr::get(elemType, 1).getValue()));
719  }
720  if (llvm::isa<FloatType>(elemType)) {
721  return builder.create<spirv::ConstantOp>(
722  loc, type,
723  DenseFPElementsAttr::get(vectorType,
724  FloatAttr::get(elemType, 1.0).getValue()));
725  }
726  }
727 
728  llvm_unreachable("unimplemented types for ConstantOp::getOne()");
729 }
730 
731 void mlir::spirv::ConstantOp::getAsmResultNames(
732  llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
733  Type type = getType();
734 
735  SmallString<32> specialNameBuffer;
736  llvm::raw_svector_ostream specialName(specialNameBuffer);
737  specialName << "cst";
738 
739  IntegerType intTy = llvm::dyn_cast<IntegerType>(type);
740 
741  if (IntegerAttr intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
742  if (intTy && intTy.getWidth() == 1) {
743  return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
744  }
745 
746  if (intTy.isSignless()) {
747  specialName << intCst.getInt();
748  } else if (intTy.isUnsigned()) {
749  specialName << intCst.getUInt();
750  } else {
751  specialName << intCst.getSInt();
752  }
753  }
754 
755  if (intTy || llvm::isa<FloatType>(type)) {
756  specialName << '_' << type;
757  }
758 
759  if (auto vecType = llvm::dyn_cast<VectorType>(type)) {
760  specialName << "_vec_";
761  specialName << vecType.getDimSize(0);
762 
763  Type elementType = vecType.getElementType();
764 
765  if (llvm::isa<IntegerType>(elementType) ||
766  llvm::isa<FloatType>(elementType)) {
767  specialName << "x" << elementType;
768  }
769  }
770 
771  setNameFn(getResult(), specialName.str());
772 }
773 
774 void mlir::spirv::AddressOfOp::getAsmResultNames(
775  llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
776  SmallString<32> specialNameBuffer;
777  llvm::raw_svector_ostream specialName(specialNameBuffer);
778  specialName << getVariable() << "_addr";
779  setNameFn(getResult(), specialName.str());
780 }
781 
782 //===----------------------------------------------------------------------===//
783 // spirv.ControlBarrierOp
784 //===----------------------------------------------------------------------===//
785 
786 LogicalResult spirv::ControlBarrierOp::verify() {
787  return verifyMemorySemantics(getOperation(), getMemorySemantics());
788 }
789 
790 //===----------------------------------------------------------------------===//
791 // spirv.EntryPoint
792 //===----------------------------------------------------------------------===//
793 
794 void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state,
795  spirv::ExecutionModel executionModel,
796  spirv::FuncOp function,
797  ArrayRef<Attribute> interfaceVars) {
798  build(builder, state,
799  spirv::ExecutionModelAttr::get(builder.getContext(), executionModel),
800  SymbolRefAttr::get(function), builder.getArrayAttr(interfaceVars));
801 }
802 
803 ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
804  OperationState &result) {
805  spirv::ExecutionModel execModel;
807  SmallVector<Type, 0> idTypes;
808  SmallVector<Attribute, 4> interfaceVars;
809 
811  if (parseEnumStrAttr<spirv::ExecutionModelAttr>(execModel, parser, result) ||
812  parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) {
813  return failure();
814  }
815 
816  if (!parser.parseOptionalComma()) {
817  // Parse the interface variables
818  if (parser.parseCommaSeparatedList([&]() -> ParseResult {
819  // The name of the interface variable attribute isnt important
820  FlatSymbolRefAttr var;
821  NamedAttrList attrs;
822  if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
823  return failure();
824  interfaceVars.push_back(var);
825  return success();
826  }))
827  return failure();
828  }
829  result.addAttribute(spirv::EntryPointOp::getInterfaceAttrName(result.name),
830  parser.getBuilder().getArrayAttr(interfaceVars));
831  return success();
832 }
833 
835  printer << " \"" << stringifyExecutionModel(getExecutionModel()) << "\" ";
836  printer.printSymbolName(getFn());
837  auto interfaceVars = getInterface().getValue();
838  if (!interfaceVars.empty()) {
839  printer << ", ";
840  llvm::interleaveComma(interfaceVars, printer);
841  }
842 }
843 
844 LogicalResult spirv::EntryPointOp::verify() {
845  // Checks for fn and interface symbol reference are done in spirv::ModuleOp
846  // verification.
847  return success();
848 }
849 
850 //===----------------------------------------------------------------------===//
851 // spirv.ExecutionMode
852 //===----------------------------------------------------------------------===//
853 
854 void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state,
855  spirv::FuncOp function,
856  spirv::ExecutionMode executionMode,
857  ArrayRef<int32_t> params) {
858  build(builder, state, SymbolRefAttr::get(function),
859  spirv::ExecutionModeAttr::get(builder.getContext(), executionMode),
860  builder.getI32ArrayAttr(params));
861 }
862 
// Parses: @fn "<ExecutionMode>"[, <i32>, <i32>, ...]
// Each trailing integer is parsed with an i32 type hint into a scratch
// NamedAttrList and its value is appended to `values`. The collected list is
// stored as an i32 array attribute under the op's values attribute name.
// NOTE(review): original line 872 (presumably the declaration of `values`) is
// missing from this listing.
// NOTE(review): `llvm::cast<IntegerAttr>` asserts if the parsed attribute is
// not an IntegerAttr (e.g. a non-integer literal was written). Consider a
// checked dyn_cast that emits a parser diagnostic instead.
863 ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
864  OperationState &result) {
865  spirv::ExecutionMode execMode;
866  Attribute fn;
867  if (parser.parseAttribute(fn, kFnNameAttrName, result.attributes) ||
868  parseEnumStrAttr<spirv::ExecutionModeAttr>(execMode, parser, result)) {
869  return failure();
870  }
871 
873  Type i32Type = parser.getBuilder().getIntegerType(32);
874  while (!parser.parseOptionalComma()) {
875  NamedAttrList attr;
876  Attribute value;
877  if (parser.parseAttribute(value, i32Type, "value", attr)) {
878  return failure();
879  }
880  values.push_back(llvm::cast<IntegerAttr>(value).getInt());
881  }
882  StringRef valuesAttrName =
883  spirv::ExecutionModeOp::getValuesAttrName(result.name);
884  result.addAttribute(valuesAttrName,
885  parser.getBuilder().getI32ArrayAttr(values));
886  return success();
887 }
888 
// Prints: @fn "<ExecutionMode>"[, v0, v1, ...]
// This is the same layout that parse() accepts. Each value is printed as a
// plain integer via IntegerAttr::getInt().
// NOTE(review): the signature line (original line 889) is missing from this
// listing.
890  printer << " ";
891  printer.printSymbolName(getFn());
892  printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\"";
893  auto values = this->getValues();
894  if (values.empty())
895  return;
896  printer << ", ";
897  llvm::interleaveComma(values, printer, [&](Attribute a) {
898  printer << llvm::cast<IntegerAttr>(a).getInt();
899  });
900 }
901 
902 //===----------------------------------------------------------------------===//
903 // spirv.func
904 //===----------------------------------------------------------------------===//
905 
// Parses a spirv.func in this order:
//   @name <signature> "<FunctionControl>" [attributes {...}] [{ body }]
// The function type is built from the parsed argument and result types.
// Argument and result attributes are attached after the attribute dictionary
// is read, and a missing body region yields an external declaration.
// NOTE(review): original lines 907 (likely the `entryArgs` declaration), 920
// (the signature-parsing call) and 943 (the attribute-attaching call) are
// missing from this listing.
906 ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
908  SmallVector<DictionaryAttr> resultAttrs;
909  SmallVector<Type> resultTypes;
910  auto &builder = parser.getBuilder();
911 
912  // Parse the name as a symbol.
913  StringAttr nameAttr;
914  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
915  result.attributes))
916  return failure();
917 
918  // Parse the function signature.
919  bool isVariadic = false;
921  parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
922  resultAttrs))
923  return failure();
924 
925  SmallVector<Type> argTypes;
926  for (auto &arg : entryArgs)
927  argTypes.push_back(arg.type);
928  auto fnType = builder.getFunctionType(argTypes, resultTypes);
929  result.addAttribute(getFunctionTypeAttrName(result.name),
930  TypeAttr::get(fnType));
931 
932  // Parse the function control string. A parse failure here fails the op.
933  spirv::FunctionControl fnControl;
934  if (parseEnumStrAttr<spirv::FunctionControlAttr>(fnControl, parser, result))
935  return failure();
936 
937  // If additional attributes are present, parse them.
938  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
939  return failure();
940 
941  // Attach the parsed argument and result attributes to the op.
942  assert(resultAttrs.size() == resultTypes.size());
944  builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
945  getResAttrsAttrName(result.name));
946 
947  // Parse the optional function body; its absence is not an error.
948  auto *body = result.addRegion();
949  OptionalParseResult parseResult =
950  parser.parseOptionalRegion(*body, entryArgs);
951  return failure(parseResult.has_value() && failed(*parseResult));
952 }
953 
// Prints a spirv.func as:
//   @name <signature> "<FunctionControl>" [attributes {...}] [{ body }]
// The function type, argument/result attributes and function control are
// elided from the attribute dictionary because the custom syntax already shows
// them. The region is printed without entry block arguments
// (printEntryBlockArgs=false) and only when the function has a body.
// NOTE(review): original lines 959 and 964 (the signature- and
// attribute-printing calls) are missing from this listing.
// NOTE(review): the elided list names both
// attributeName<spirv::FunctionControl>() and getFunctionControlAttrName().
// These are likely the same attribute; verify and drop the duplicate.
954 void spirv::FuncOp::print(OpAsmPrinter &printer) {
955  // Print function name, signature, and control.
956  printer << " ";
957  printer.printSymbolName(getSymName());
958  auto fnType = getFunctionType();
960  printer, *this, fnType.getInputs(),
961  /*isVariadic=*/false, fnType.getResults());
962  printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl())
963  << "\"";
965  printer, *this,
966  {spirv::attributeName<spirv::FunctionControl>(),
967  getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
968  getFunctionControlAttrName()});
969 
970  // Print the body if this is not an external function.
971  Region &body = this->getBody();
972  if (!body.empty()) {
973  printer << ' ';
974  printer.printRegion(body, /*printEntryBlockArgs=*/false,
975  /*printBlockTerminators=*/true);
976  }
977 }
978 
979 LogicalResult spirv::FuncOp::verifyType() {
980  FunctionType fnType = getFunctionType();
981  if (fnType.getNumResults() > 1)
982  return emitOpError("cannot have more than one result");
983 
984  auto hasDecorationAttr = [&](spirv::Decoration decoration,
985  unsigned argIndex) {
986  auto func = llvm::cast<FunctionOpInterface>(getOperation());
987  for (auto argAttr : cast<FunctionOpInterface>(func).getArgAttrs(argIndex)) {
988  if (argAttr.getName() != spirv::DecorationAttr::name)
989  continue;
990  if (auto decAttr = dyn_cast<spirv::DecorationAttr>(argAttr.getValue()))
991  return decAttr.getValue() == decoration;
992  }
993  return false;
994  };
995 
996  for (unsigned i = 0, e = this->getNumArguments(); i != e; ++i) {
997  Type param = fnType.getInputs()[i];
998  auto inputPtrType = dyn_cast<spirv::PointerType>(param);
999  if (!inputPtrType)
1000  continue;
1001 
1002  auto pointeePtrType =
1003  dyn_cast<spirv::PointerType>(inputPtrType.getPointeeType());
1004  if (pointeePtrType) {
1005  // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
1006  // > If an OpFunctionParameter is a pointer (or contains a pointer)
1007  // > and the type it points to is a pointer in the PhysicalStorageBuffer
1008  // > storage class, the function parameter must be decorated with exactly
1009  // > one of AliasedPointer or RestrictPointer.
1010  if (pointeePtrType.getStorageClass() !=
1011  spirv::StorageClass::PhysicalStorageBuffer)
1012  continue;
1013 
1014  bool hasAliasedPtr =
1015  hasDecorationAttr(spirv::Decoration::AliasedPointer, i);
1016  bool hasRestrictPtr =
1017  hasDecorationAttr(spirv::Decoration::RestrictPointer, i);
1018  if (!hasAliasedPtr && !hasRestrictPtr)
1019  return emitOpError()
1020  << "with a pointer points to a physical buffer pointer must "
1021  "be decorated either 'AliasedPointer' or 'RestrictPointer'";
1022  continue;
1023  }
1024  // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
1025  // > If an OpFunctionParameter is a pointer (or contains a pointer) in
1026  // > the PhysicalStorageBuffer storage class, the function parameter must
1027  // > be decorated with exactly one of Aliased or Restrict.
1028  if (auto pointeeArrayType =
1029  dyn_cast<spirv::ArrayType>(inputPtrType.getPointeeType())) {
1030  pointeePtrType =
1031  dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
1032  } else {
1033  pointeePtrType = inputPtrType;
1034  }
1035 
1036  if (!pointeePtrType || pointeePtrType.getStorageClass() !=
1037  spirv::StorageClass::PhysicalStorageBuffer)
1038  continue;
1039 
1040  bool hasAliased = hasDecorationAttr(spirv::Decoration::Aliased, i);
1041  bool hasRestrict = hasDecorationAttr(spirv::Decoration::Restrict, i);
1042  if (!hasAliased && !hasRestrict)
1043  return emitOpError() << "with physical buffer pointer must be decorated "
1044  "either 'Aliased' or 'Restrict'";
1045  }
1046 
1047  return success();
1048 }
1049 
1050 LogicalResult spirv::FuncOp::verifyBody() {
1051  FunctionType fnType = getFunctionType();
1052 
1053  auto walkResult = walk([fnType](Operation *op) -> WalkResult {
1054  if (auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
1055  if (fnType.getNumResults() != 0)
1056  return retOp.emitOpError("cannot be used in functions returning value");
1057  } else if (auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
1058  if (fnType.getNumResults() != 1)
1059  return retOp.emitOpError(
1060  "returns 1 value but enclosing function requires ")
1061  << fnType.getNumResults() << " results";
1062 
1063  auto retOperandType = retOp.getValue().getType();
1064  auto fnResultType = fnType.getResult(0);
1065  if (retOperandType != fnResultType)
1066  return retOp.emitOpError(" return value's type (")
1067  << retOperandType << ") mismatch with function's result type ("
1068  << fnResultType << ")";
1069  }
1070  return WalkResult::advance();
1071  });
1072 
1073  // TODO: verify other bits like linkage type.
1074 
1075  return failure(walkResult.wasInterrupted());
1076 }
1077 
1078 void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
1079  StringRef name, FunctionType type,
1080  spirv::FunctionControl control,
1081  ArrayRef<NamedAttribute> attrs) {
1082  state.addAttribute(SymbolTable::getSymbolAttrName(),
1083  builder.getStringAttr(name));
1084  state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
1085  state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
1086  builder.getAttr<spirv::FunctionControlAttr>(control));
1087  state.attributes.append(attrs.begin(), attrs.end());
1088  state.addRegion();
1089 }
1090 
1091 //===----------------------------------------------------------------------===//
1092 // spirv.GLFClampOp
1093 //===----------------------------------------------------------------------===//
1094 
1095 ParseResult spirv::GLFClampOp::parse(OpAsmParser &parser,
1096  OperationState &result) {
1097  return parseOneResultSameOperandTypeOp(parser, result);
1098 }
1100 
1101 //===----------------------------------------------------------------------===//
1102 // spirv.GLUClampOp
1103 //===----------------------------------------------------------------------===//
1104 
1105 ParseResult spirv::GLUClampOp::parse(OpAsmParser &parser,
1106  OperationState &result) {
1107  return parseOneResultSameOperandTypeOp(parser, result);
1108 }
1110 
1111 //===----------------------------------------------------------------------===//
1112 // spirv.GLSClampOp
1113 //===----------------------------------------------------------------------===//
1114 
1115 ParseResult spirv::GLSClampOp::parse(OpAsmParser &parser,
1116  OperationState &result) {
1117  return parseOneResultSameOperandTypeOp(parser, result);
1118 }
1120 
1121 //===----------------------------------------------------------------------===//
1122 // spirv.GLFmaOp
1123 //===----------------------------------------------------------------------===//
1124 
1125 ParseResult spirv::GLFmaOp::parse(OpAsmParser &parser, OperationState &result) {
1126  return parseOneResultSameOperandTypeOp(parser, result);
1127 }
1128 void spirv::GLFmaOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1129 
1130 //===----------------------------------------------------------------------===//
1131 // spirv.GlobalVariable
1132 //===----------------------------------------------------------------------===//
1133 
1134 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
1135  Type type, StringRef name,
1136  unsigned descriptorSet, unsigned binding) {
1137  build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
1138  state.addAttribute(
1139  spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
1140  builder.getI32IntegerAttr(descriptorSet));
1141  state.addAttribute(
1142  spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
1143  builder.getI32IntegerAttr(binding));
1144 }
1145 
1146 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
1147  Type type, StringRef name,
1148  spirv::BuiltIn builtin) {
1149  build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
1150  state.addAttribute(
1151  spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
1152  builder.getStringAttr(spirv::stringifyBuiltIn(builtin)));
1153 }
1154 
1155 ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
1156  OperationState &result) {
1157  // Parse variable name.
1158  StringAttr nameAttr;
1159  StringRef initializerAttrName =
1160  spirv::GlobalVariableOp::getInitializerAttrName(result.name);
1161  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1162  result.attributes)) {
1163  return failure();
1164  }
1165 
1166  // Parse optional initializer
1167  if (succeeded(parser.parseOptionalKeyword(initializerAttrName))) {
1168  FlatSymbolRefAttr initSymbol;
1169  if (parser.parseLParen() ||
1170  parser.parseAttribute(initSymbol, Type(), initializerAttrName,
1171  result.attributes) ||
1172  parser.parseRParen())
1173  return failure();
1174  }
1175 
1176  if (parseVariableDecorations(parser, result)) {
1177  return failure();
1178  }
1179 
1180  Type type;
1181  StringRef typeAttrName =
1182  spirv::GlobalVariableOp::getTypeAttrName(result.name);
1183  auto loc = parser.getCurrentLocation();
1184  if (parser.parseColonType(type)) {
1185  return failure();
1186  }
1187  if (!llvm::isa<spirv::PointerType>(type)) {
1188  return parser.emitError(loc, "expected spirv.ptr type");
1189  }
1190  result.addAttribute(typeAttrName, TypeAttr::get(type));
1191 
1192  return success();
1193 }
1194 
// Prints: @name [<initializer-attr-name>(@symbol)] <decorations> : <type>
// The storage class, symbol name, initializer (when present) and type
// attributes are elided from the decoration/attribute output because the
// custom syntax already shows them.
// NOTE(review): the signature line (original line 1195) is missing from this
// listing.
1196  SmallVector<StringRef, 4> elidedAttrs{
1197  spirv::attributeName<spirv::StorageClass>()};
1198 
1199  // Print variable name.
1200  printer << ' ';
1201  printer.printSymbolName(getSymName());
1202  elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
1203 
1204  StringRef initializerAttrName = this->getInitializerAttrName();
1205  // Print optional initializer
1206  if (auto initializer = this->getInitializer()) {
1207  printer << " " << initializerAttrName << '(';
1208  printer.printSymbolName(*initializer);
1209  printer << ')';
1210  elidedAttrs.push_back(initializerAttrName);
1211  }
1212 
1213  StringRef typeAttrName = this->getTypeAttrName();
1214  elidedAttrs.push_back(typeAttrName);
1215  spirv::printVariableDecorations(*this, printer, elidedAttrs);
1216  printer << " : " << getType();
1217 }
1218 
// Verifies a spirv.GlobalVariable:
//  * its type is a SPIR-V pointer;
//  * its storage class is neither Generic nor Function;
//  * an initializer, if present, names a spirv.GlobalVariable,
//    spirv.SpecConstant or spirv.SpecConstantComposite, looked up from the
//    parent op.
// NOTE(review): original line 1236 (the start of the `initOp` lookup
// statement) is missing from this listing.
// NOTE(review): the first diagnostic spells the type "!spv.ptr", while the
// parser reports "spirv.ptr"; consider aligning the spelling.
1219 LogicalResult spirv::GlobalVariableOp::verify() {
1220  if (!llvm::isa<spirv::PointerType>(getType()))
1221  return emitOpError("result must be of a !spv.ptr type");
1222 
1223  // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
1224  // object. It cannot be Generic. It must be the same as the Storage Class
1225  // operand of the Result Type."
1226  // Also, Function storage class is reserved by spirv.Variable.
1227  auto storageClass = this->storageClass();
1228  if (storageClass == spirv::StorageClass::Generic ||
1229  storageClass == spirv::StorageClass::Function) {
1230  return emitOpError("storage class cannot be '")
1231  << stringifyStorageClass(storageClass) << "'";
1232  }
1233 
1234  if (auto init = (*this)->getAttrOfType<FlatSymbolRefAttr>(
1235  this->getInitializerAttrName())) {
1237  (*this)->getParentOp(), init.getAttr());
1238  // TODO: Currently only variable initialization with specialization
1239  // constants and other variables is supported. They could be normal
1240  // constants in the module scope as well.
1241  if (!initOp || !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp,
1242  spirv::SpecConstantCompositeOp>(initOp)) {
1243  return emitOpError("initializer must be result of a "
1244  "spirv.SpecConstant or spirv.GlobalVariable or "
1245  "spirv.SpecConstantCompositeOp op");
1246  }
1247  }
1248 
1249  return success();
1250 }
1251 
1252 //===----------------------------------------------------------------------===//
1253 // spirv.INTEL.SubgroupBlockRead
1254 //===----------------------------------------------------------------------===//
1255 
// Verifier for the INTEL subgroup block read: delegates the pointer/value
// type compatibility check to verifyBlockReadWritePtrAndValTypes.
// NOTE(review): the signature line (original line 1256) is missing from this
// listing.
1257  if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
1258  return failure();
1259 
1260  return success();
1261 }
1262 
1263 //===----------------------------------------------------------------------===//
1264 // spirv.INTEL.SubgroupBlockWrite
1265 //===----------------------------------------------------------------------===//
1266 
1268  OperationState &result) {
// Parses: "<StorageClass>" %ptr, %value : <value type>
// The pointer operand is resolved as a pointer (in the parsed storage class)
// to the value type. When the value is a vector, the pointee is the vector's
// element type instead.
// NOTE(review): original lines 1267 (the first signature line) and 1271
// (presumably the `operandInfo` declaration) are missing from this listing.
1269  // Parse the storage class specification
1270  spirv::StorageClass storageClass;
1272  auto loc = parser.getCurrentLocation();
1273  Type elementType;
1274  if (parseEnumStrAttr(storageClass, parser) ||
1275  parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
1276  parser.parseType(elementType)) {
1277  return failure();
1278  }
1279 
1280  auto ptrType = spirv::PointerType::get(elementType, storageClass);
1281  if (auto valVecTy = llvm::dyn_cast<VectorType>(elementType))
1282  ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
1283 
1284  if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
1285  result.operands)) {
1286  return failure();
1287  }
1288  return success();
1289 }
1290 
// Prints: %ptr, %value : <value type>
// NOTE(review): the signature line (original line 1291) is missing from this
// listing. This printer does not print a storage class, while the parser above
// reads one; confirm that round-tripping is handled elsewhere.
1292  printer << " " << getPtr() << ", " << getValue() << " : "
1293  << getValue().getType();
1294 }
1295 
// Verifier for the INTEL subgroup block write: delegates the pointer/value
// type compatibility check to verifyBlockReadWritePtrAndValTypes.
// NOTE(review): the signature line (original line 1296) is missing from this
// listing.
1297  if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
1298  return failure();
1299 
1300  return success();
1301 }
1302 
1303 //===----------------------------------------------------------------------===//
1304 // spirv.IAddCarryOp
1305 //===----------------------------------------------------------------------===//
1306 
// Verifier for spirv.IAddCarry.
// NOTE(review): the body (original line 1308) is missing from this listing.
1307 LogicalResult spirv::IAddCarryOp::verify() {
1309 }
1310 
// Custom parser for spirv.IAddCarry.
// NOTE(review): the body (original line 1313) is missing from this listing.
1311 ParseResult spirv::IAddCarryOp::parse(OpAsmParser &parser,
1312  OperationState &result) {
1314 }
1315 
1316 void spirv::IAddCarryOp::print(OpAsmPrinter &printer) {
1317  ::printArithmeticExtendedBinaryOp(*this, printer);
1318 }
1319 
1320 //===----------------------------------------------------------------------===//
1321 // spirv.ISubBorrowOp
1322 //===----------------------------------------------------------------------===//
1323 
// Verifier for spirv.ISubBorrow.
// NOTE(review): the body (original line 1325) is missing from this listing.
1324 LogicalResult spirv::ISubBorrowOp::verify() {
1326 }
1327 
// Custom parser for spirv.ISubBorrow.
// NOTE(review): the body (original line 1330) is missing from this listing.
1328 ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser,
1329  OperationState &result) {
1331 }
1332 
// Prints via the shared file-local ::printArithmeticExtendedBinaryOp helper.
// NOTE(review): the signature line (original line 1333) is missing from this
// listing.
1334  ::printArithmeticExtendedBinaryOp(*this, printer);
1335 }
1336 
1337 //===----------------------------------------------------------------------===//
1338 // spirv.SMulExtended
1339 //===----------------------------------------------------------------------===//
1340 
// Verifier for spirv.SMulExtended.
// NOTE(review): the body (original line 1342) is missing from this listing.
1341 LogicalResult spirv::SMulExtendedOp::verify() {
1343 }
1344 
// Custom parser for spirv.SMulExtended.
// NOTE(review): the body (original line 1347) is missing from this listing.
1345 ParseResult spirv::SMulExtendedOp::parse(OpAsmParser &parser,
1346  OperationState &result) {
1348 }
1349 
// Prints via the shared file-local ::printArithmeticExtendedBinaryOp helper.
// NOTE(review): the signature line (original line 1350) is missing from this
// listing.
1351  ::printArithmeticExtendedBinaryOp(*this, printer);
1352 }
1353 
1354 //===----------------------------------------------------------------------===//
1355 // spirv.UMulExtended
1356 //===----------------------------------------------------------------------===//
1357 
// Verifier for spirv.UMulExtended.
// NOTE(review): the body (original line 1359) is missing from this listing.
1358 LogicalResult spirv::UMulExtendedOp::verify() {
1360 }
1361 
// Custom parser for spirv.UMulExtended.
// NOTE(review): the body (original line 1364) is missing from this listing.
1362 ParseResult spirv::UMulExtendedOp::parse(OpAsmParser &parser,
1363  OperationState &result) {
1365 }
1366 
// Prints via the shared file-local ::printArithmeticExtendedBinaryOp helper.
// NOTE(review): the signature line (original line 1367) is missing from this
// listing.
1368  ::printArithmeticExtendedBinaryOp(*this, printer);
1369 }
1370 
1371 //===----------------------------------------------------------------------===//
1372 // spirv.MemoryBarrierOp
1373 //===----------------------------------------------------------------------===//
1374 
1375 LogicalResult spirv::MemoryBarrierOp::verify() {
1376  return verifyMemorySemantics(getOperation(), getMemorySemantics());
1377 }
1378 
1379 //===----------------------------------------------------------------------===//
1380 // spirv.module
1381 //===----------------------------------------------------------------------===//
1382 
1383 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
1384  std::optional<StringRef> name) {
1385  OpBuilder::InsertionGuard guard(builder);
1386  builder.createBlock(state.addRegion());
1387  if (name) {
1388  state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
1389  builder.getStringAttr(*name));
1390  }
1391 }
1392 
1393 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
1394  spirv::AddressingModel addressingModel,
1395  spirv::MemoryModel memoryModel,
1396  std::optional<VerCapExtAttr> vceTriple,
1397  std::optional<StringRef> name) {
1398  state.addAttribute(
1399  "addressing_model",
1400  builder.getAttr<spirv::AddressingModelAttr>(addressingModel));
1401  state.addAttribute("memory_model",
1402  builder.getAttr<spirv::MemoryModelAttr>(memoryModel));
1403  OpBuilder::InsertionGuard guard(builder);
1404  builder.createBlock(state.addRegion());
1405  if (vceTriple)
1406  state.addAttribute(getVCETripleAttrName(), *vceTriple);
1407  if (name)
1408  state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
1409  builder.getStringAttr(*name));
1410 }
1411 
1412 ParseResult spirv::ModuleOp::parse(OpAsmParser &parser,
1413  OperationState &result) {
1414  Region *body = result.addRegion();
1415 
1416  // If the name is present, parse it.
1417  StringAttr nameAttr;
1418  (void)parser.parseOptionalSymbolName(
1419  nameAttr, mlir::SymbolTable::getSymbolAttrName(), result.attributes);
1420 
1421  // Parse attributes
1422  spirv::AddressingModel addrModel;
1423  spirv::MemoryModel memoryModel;
1424  if (spirv::parseEnumKeywordAttr<spirv::AddressingModelAttr>(addrModel, parser,
1425  result) ||
1426  spirv::parseEnumKeywordAttr<spirv::MemoryModelAttr>(memoryModel, parser,
1427  result))
1428  return failure();
1429 
1430  if (succeeded(parser.parseOptionalKeyword("requires"))) {
1431  spirv::VerCapExtAttr vceTriple;
1432  if (parser.parseAttribute(vceTriple,
1433  spirv::ModuleOp::getVCETripleAttrName(),
1434  result.attributes))
1435  return failure();
1436  }
1437 
1438  if (parser.parseOptionalAttrDictWithKeyword(result.attributes) ||
1439  parser.parseRegion(*body, /*arguments=*/{}))
1440  return failure();
1441 
1442  // Make sure we have at least one block.
1443  if (body->empty())
1444  body->push_back(new Block());
1445 
1446  return success();
1447 }
1448 
// Prints:
//   [@name] <AddressingModel> <MemoryModel> [requires <vce triple>]
//   [attributes {...}] <region>
// This mirrors the order parse() accepts. Attributes already shown in custom
// syntax are elided from the attribute dictionary.
// NOTE(review): original line 1462 (the rest of the `elidedAttrs.assign(...)`
// initializer list) is missing from this listing.
1449 void spirv::ModuleOp::print(OpAsmPrinter &printer) {
1450  if (std::optional<StringRef> name = getName()) {
1451  printer << ' ';
1452  printer.printSymbolName(*name);
1453  }
1454 
1455  SmallVector<StringRef, 2> elidedAttrs;
1456 
1457  printer << " " << spirv::stringifyAddressingModel(getAddressingModel()) << " "
1458  << spirv::stringifyMemoryModel(getMemoryModel());
1459  auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
1460  auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
1461  elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
1463 
1464  if (std::optional<spirv::VerCapExtAttr> triple = getVceTriple()) {
1465  printer << " requires " << *triple;
1466  elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
1467  }
1468 
1469  printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs);
1470  printer << ' ';
1471  printer.printRegion(getRegion());
1472 }
1473 
// Verifies the module body:
//  * every top-level op belongs to the SPIR-V dialect;
//  * each spirv.EntryPoint names an existing spirv.func. Its interface entries
//    must be flat symbol references to spirv.GlobalVariable ops, and no two
//    entry points may share the same (function, execution model) pair;
//  * each external spirv.func has 'Import' linkage;
//  * ops directly inside function blocks belong to the SPIR-V dialect. Ops
//    nested in those ops' regions are not visited by this inner loop.
// NOTE(review): original line 1476 (the map type of `entryPoints`, keyed by
// (FuncOp, ExecutionModel)) is missing from this listing.
// NOTE(review): two diagnostics have malformed quoting: "instead of '" never
// closes its quote, and "instead of'" lacks a space before the quote.
1474 LogicalResult spirv::ModuleOp::verifyRegions() {
1475  Dialect *dialect = (*this)->getDialect();
1477  entryPoints;
1478  mlir::SymbolTable table(*this);
1479 
1480  for (auto &op : *getBody()) {
1481  if (op.getDialect() != dialect)
1482  return op.emitError("'spirv.module' can only contain spirv.* ops");
1483 
1484  // For EntryPoint op, check that the function and execution model is not
1485  // duplicated in EntryPointOps. Also verify that the interface specified
1486  // comes from globalVariables here to make this check cheaper.
1487  if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
1488  auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.getFn());
1489  if (!funcOp) {
1490  return entryPointOp.emitError("function '")
1491  << entryPointOp.getFn() << "' not found in 'spirv.module'";
1492  }
1493  if (auto interface = entryPointOp.getInterface()) {
1494  for (Attribute varRef : interface) {
1495  auto varSymRef = llvm::dyn_cast<FlatSymbolRefAttr>(varRef);
1496  if (!varSymRef) {
1497  return entryPointOp.emitError(
1498  "expected symbol reference for interface "
1499  "specification instead of '")
1500  << varRef;
1501  }
1502  auto variableOp =
1503  table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
1504  if (!variableOp) {
1505  return entryPointOp.emitError("expected spirv.GlobalVariable "
1506  "symbol reference instead of'")
1507  << varSymRef << "'";
1508  }
1509  }
1510  }
1511 
1512  auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
1513  funcOp, entryPointOp.getExecutionModel());
1514  if (!entryPoints.try_emplace(key, entryPointOp).second)
1515  return entryPointOp.emitError("duplicate of a previous EntryPointOp");
1516  } else if (auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
1517  // If the function is external and does not have 'Import'
1518  // linkage_attributes(LinkageAttributes), report an error. 'Import'
1519  // LinkageAttributes is used to import external functions.
1520  auto linkageAttr = funcOp.getLinkageAttributes();
1521  auto hasImportLinkage =
1522  linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
1523  spirv::LinkageType::Import);
1524  if (funcOp.isExternal() && !hasImportLinkage)
1525  return op.emitError(
1526  "'spirv.module' cannot contain external functions "
1527  "without 'Import' linkage_attributes (LinkageAttributes)");
1528 
1529  // TODO: move this check to spirv.func.
1530  for (auto &block : funcOp)
1531  for (auto &op : block) {
1532  if (op.getDialect() != dialect)
1533  return op.emitError(
1534  "functions in 'spirv.module' can only contain spirv.* ops");
1535  }
1536  }
1537  }
1538 
1539  return success();
1540 }
1541 
1542 //===----------------------------------------------------------------------===//
1543 // spirv.mlir.referenceof
1544 //===----------------------------------------------------------------------===//
1545 
1546 LogicalResult spirv::ReferenceOfOp::verify() {
1547  auto *specConstSym = SymbolTable::lookupNearestSymbolFrom(
1548  (*this)->getParentOp(), getSpecConstAttr());
1549  Type constType;
1550 
1551  auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
1552  if (specConstOp)
1553  constType = specConstOp.getDefaultValue().getType();
1554 
1555  auto specConstCompositeOp =
1556  dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
1557  if (specConstCompositeOp)
1558  constType = specConstCompositeOp.getType();
1559 
1560  if (!specConstOp && !specConstCompositeOp)
1561  return emitOpError(
1562  "expected spirv.SpecConstant or spirv.SpecConstantComposite symbol");
1563 
1564  if (getReference().getType() != constType)
1565  return emitOpError("result type mismatch with the referenced "
1566  "specialization constant's type");
1567 
1568  return success();
1569 }
1570 
1571 //===----------------------------------------------------------------------===//
1572 // spirv.SpecConstant
1573 //===----------------------------------------------------------------------===//
1574 
1575 ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
1576  OperationState &result) {
1577  StringAttr nameAttr;
1578  Attribute valueAttr;
1579  StringRef defaultValueAttrName =
1580  spirv::SpecConstantOp::getDefaultValueAttrName(result.name);
1581 
1582  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1583  result.attributes))
1584  return failure();
1585 
1586  // Parse optional spec_id.
1587  if (succeeded(parser.parseOptionalKeyword(kSpecIdAttrName))) {
1588  IntegerAttr specIdAttr;
1589  if (parser.parseLParen() ||
1590  parser.parseAttribute(specIdAttr, kSpecIdAttrName, result.attributes) ||
1591  parser.parseRParen())
1592  return failure();
1593  }
1594 
1595  if (parser.parseEqual() ||
1596  parser.parseAttribute(valueAttr, defaultValueAttrName, result.attributes))
1597  return failure();
1598 
1599  return success();
1600 }
1601 
// Prints: @name [<kSpecIdAttrName>(<int>)] = <default value>
// The spec id clause appears only when the op carries that integer attribute.
// NOTE(review): the signature line (original line 1602) is missing from this
// listing.
1603  printer << ' ';
1604  printer.printSymbolName(getSymName());
1605  if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
1606  printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')';
1607  printer << " = " << getDefaultValue();
1608 }
1609 
1610 LogicalResult spirv::SpecConstantOp::verify() {
1611  if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
1612  if (specID.getValue().isNegative())
1613  return emitOpError("SpecId cannot be negative");
1614 
1615  auto value = getDefaultValue();
1616  if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
1617  // Make sure bitwidth is allowed.
1618  if (!llvm::isa<spirv::SPIRVType>(value.getType()))
1619  return emitOpError("default value bitwidth disallowed");
1620  return success();
1621  }
1622  return emitOpError(
1623  "default value can only be a bool, integer, or float scalar");
1624 }
1625 
1626 //===----------------------------------------------------------------------===//
1627 // spirv.VectorShuffle
1628 //===----------------------------------------------------------------------===//
1629 
1630 LogicalResult spirv::VectorShuffleOp::verify() {
1631  VectorType resultType = llvm::cast<VectorType>(getType());
1632 
1633  size_t numResultElements = resultType.getNumElements();
1634  if (numResultElements != getComponents().size())
1635  return emitOpError("result type element count (")
1636  << numResultElements
1637  << ") mismatch with the number of component selectors ("
1638  << getComponents().size() << ")";
1639 
1640  size_t totalSrcElements =
1641  llvm::cast<VectorType>(getVector1().getType()).getNumElements() +
1642  llvm::cast<VectorType>(getVector2().getType()).getNumElements();
1643 
1644  for (const auto &selector : getComponents().getAsValueRange<IntegerAttr>()) {
1645  uint32_t index = selector.getZExtValue();
1646  if (index >= totalSrcElements &&
1647  index != std::numeric_limits<uint32_t>().max())
1648  return emitOpError("component selector ")
1649  << index << " out of range: expected to be in [0, "
1650  << totalSrcElements << ") or 0xffffffff";
1651  }
1652  return success();
1653 }
1654 
1655 //===----------------------------------------------------------------------===//
1656 // spirv.MatrixTimesScalar
1657 //===----------------------------------------------------------------------===//
1658 
// Verifies that the scalar operand has the same type as the matrix's element
// type. The element type is extracted with a TypeSwitch; unhandled matrix
// types yield a null Type.
// NOTE(review): original line 1662 (the `.Case<...>(` line naming the handled
// matrix types) is missing from this listing.
// NOTE(review): the assert only fires in debug builds. In release builds a
// null elementType makes the type comparison below fail, so the op is rejected
// with the "same type" diagnostic.
1659 LogicalResult spirv::MatrixTimesScalarOp::verify() {
1660  Type elementType =
1661  llvm::TypeSwitch<Type, Type>(getMatrix().getType())
1663  [](auto matrixType) { return matrixType.getElementType(); })
1664  .Default([](Type) { return nullptr; });
1665 
1666  assert(elementType && "Unhandled type");
1667 
1668  // Check that the scalar type is the same as the matrix element type.
1669  if (getScalar().getType() != elementType)
1670  return emitOpError("input matrix components' type and scaling value must "
1671  "have the same type");
1672 
1673  return success();
1674 }
1675 
1676 //===----------------------------------------------------------------------===//
1677 // spirv.Transpose
1678 //===----------------------------------------------------------------------===//
1679 
1680 LogicalResult spirv::TransposeOp::verify() {
1681  auto inputMatrix = llvm::cast<spirv::MatrixType>(getMatrix().getType());
1682  auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().getType());
1683 
1684  // Verify that the input and output matrices have correct shapes.
1685  if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
1686  return emitError("input matrix rows count must be equal to "
1687  "output matrix columns count");
1688 
1689  if (inputMatrix.getNumColumns() != resultMatrix.getNumRows())
1690  return emitError("input matrix columns count must be equal to "
1691  "output matrix rows count");
1692 
1693  // Verify that the input and output matrices have the same component type
1694  if (inputMatrix.getElementType() != resultMatrix.getElementType())
1695  return emitError("input and output matrices must have the same "
1696  "component type");
1697 
1698  return success();
1699 }
1700 
1701 //===----------------------------------------------------------------------===//
1702 // spirv.MatrixTimesMatrix
1703 //===----------------------------------------------------------------------===//
1704 
1705 LogicalResult spirv::MatrixTimesMatrixOp::verify() {
1706  auto leftMatrix = llvm::cast<spirv::MatrixType>(getLeftmatrix().getType());
1707  auto rightMatrix = llvm::cast<spirv::MatrixType>(getRightmatrix().getType());
1708  auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().getType());
1709 
1710  // left matrix columns' count and right matrix rows' count must be equal
1711  if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
1712  return emitError("left matrix columns' count must be equal to "
1713  "the right matrix rows' count");
1714 
1715  // right and result matrices columns' count must be the same
1716  if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns())
1717  return emitError(
1718  "right and result matrices must have equal columns' count");
1719 
1720  // right and result matrices component type must be the same
1721  if (rightMatrix.getElementType() != resultMatrix.getElementType())
1722  return emitError("right and result matrices' component type must"
1723  " be the same");
1724 
1725  // left and result matrices component type must be the same
1726  if (leftMatrix.getElementType() != resultMatrix.getElementType())
1727  return emitError("left and result matrices' component type"
1728  " must be the same");
1729 
1730  // left and result matrices rows count must be the same
1731  if (leftMatrix.getNumRows() != resultMatrix.getNumRows())
1732  return emitError("left and result matrices must have equal rows' count");
1733 
1734  return success();
1735 }
1736 
1737 //===----------------------------------------------------------------------===//
1738 // spirv.SpecConstantComposite
1739 //===----------------------------------------------------------------------===//
1740 
1742  OperationState &result) {
1743 
1744  StringAttr compositeName;
1745  if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
1746  result.attributes))
1747  return failure();
1748 
1749  if (parser.parseLParen())
1750  return failure();
1751 
1752  SmallVector<Attribute, 4> constituents;
1753 
1754  do {
1755  // The name of the constituent attribute isn't important
1756  const char *attrName = "spec_const";
1757  FlatSymbolRefAttr specConstRef;
1758  NamedAttrList attrs;
1759 
1760  if (parser.parseAttribute(specConstRef, Type(), attrName, attrs))
1761  return failure();
1762 
1763  constituents.push_back(specConstRef);
1764  } while (!parser.parseOptionalComma());
1765 
1766  if (parser.parseRParen())
1767  return failure();
1768 
1769  StringAttr compositeSpecConstituentsName =
1770  spirv::SpecConstantCompositeOp::getConstituentsAttrName(result.name);
1771  result.addAttribute(compositeSpecConstituentsName,
1772  parser.getBuilder().getArrayAttr(constituents));
1773 
1774  Type type;
1775  if (parser.parseColonType(type))
1776  return failure();
1777 
1778  StringAttr typeAttrName =
1779  spirv::SpecConstantCompositeOp::getTypeAttrName(result.name);
1780  result.addAttribute(typeAttrName, TypeAttr::get(type));
1781 
1782  return success();
1783 }
1784 
1786  printer << " ";
1787  printer.printSymbolName(getSymName());
1788  printer << " (";
1789  auto constituents = this->getConstituents().getValue();
1790 
1791  if (!constituents.empty())
1792  llvm::interleaveComma(constituents, printer);
1793 
1794  printer << ") : " << getType();
1795 }
1796 
1798  auto cType = llvm::dyn_cast<spirv::CompositeType>(getType());
1799  auto constituents = this->getConstituents().getValue();
1800 
1801  if (!cType)
1802  return emitError("result type must be a composite type, but provided ")
1803  << getType();
1804 
1805  if (llvm::isa<spirv::CooperativeMatrixType>(cType))
1806  return emitError("unsupported composite type ") << cType;
1807  if (constituents.size() != cType.getNumElements())
1808  return emitError("has incorrect number of operands: expected ")
1809  << cType.getNumElements() << ", but provided "
1810  << constituents.size();
1811 
1812  for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
1813  auto constituent = llvm::cast<FlatSymbolRefAttr>(constituents[index]);
1814 
1815  auto constituentSpecConstOp =
1816  dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom(
1817  (*this)->getParentOp(), constituent.getAttr()));
1818 
1819  if (constituentSpecConstOp.getDefaultValue().getType() !=
1820  cType.getElementType(index))
1821  return emitError("has incorrect types of operands: expected ")
1822  << cType.getElementType(index) << ", but provided "
1823  << constituentSpecConstOp.getDefaultValue().getType();
1824  }
1825 
1826  return success();
1827 }
1828 
1829 //===----------------------------------------------------------------------===//
1830 // spirv.SpecConstantOperation
1831 //===----------------------------------------------------------------------===//
1832 
1834  OperationState &result) {
1835  Region *body = result.addRegion();
1836 
1837  if (parser.parseKeyword("wraps"))
1838  return failure();
1839 
1840  body->push_back(new Block);
1841  Block &block = body->back();
1842  Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
1843 
1844  if (!wrappedOp)
1845  return failure();
1846 
1847  OpBuilder builder(parser.getContext());
1848  builder.setInsertionPointToEnd(&block);
1849  builder.create<spirv::YieldOp>(wrappedOp->getLoc(), wrappedOp->getResult(0));
1850  result.location = wrappedOp->getLoc();
1851 
1852  result.addTypes(wrappedOp->getResult(0).getType());
1853 
1854  if (parser.parseOptionalAttrDict(result.attributes))
1855  return failure();
1856 
1857  return success();
1858 }
1859 
1861  printer << " wraps ";
1862  printer.printGenericOp(&getBody().front().front());
1863 }
1864 
1865 LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
1866  Block &block = getRegion().getBlocks().front();
1867 
1868  if (block.getOperations().size() != 2)
1869  return emitOpError("expected exactly 2 nested ops");
1870 
1871  Operation &enclosedOp = block.getOperations().front();
1872 
1874  return emitOpError("invalid enclosed op");
1875 
1876  for (auto operand : enclosedOp.getOperands())
1877  if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
1878  spirv::SpecConstantOperationOp>(operand.getDefiningOp()))
1879  return emitOpError(
1880  "invalid operand, must be defined by a constant operation");
1881 
1882  return success();
1883 }
1884 
1885 //===----------------------------------------------------------------------===//
1886 // spirv.GL.FrexpStruct
1887 //===----------------------------------------------------------------------===//
1888 
1889 LogicalResult spirv::GLFrexpStructOp::verify() {
1890  spirv::StructType structTy =
1891  llvm::dyn_cast<spirv::StructType>(getResult().getType());
1892 
1893  if (structTy.getNumElements() != 2)
1894  return emitError("result type must be a struct type with two memebers");
1895 
1896  Type significandTy = structTy.getElementType(0);
1897  Type exponentTy = structTy.getElementType(1);
1898  VectorType exponentVecTy = llvm::dyn_cast<VectorType>(exponentTy);
1899  IntegerType exponentIntTy = llvm::dyn_cast<IntegerType>(exponentTy);
1900 
1901  Type operandTy = getOperand().getType();
1902  VectorType operandVecTy = llvm::dyn_cast<VectorType>(operandTy);
1903  FloatType operandFTy = llvm::dyn_cast<FloatType>(operandTy);
1904 
1905  if (significandTy != operandTy)
1906  return emitError("member zero of the resulting struct type must be the "
1907  "same type as the operand");
1908 
1909  if (exponentVecTy) {
1910  IntegerType componentIntTy =
1911  llvm::dyn_cast<IntegerType>(exponentVecTy.getElementType());
1912  if (!componentIntTy || componentIntTy.getWidth() != 32)
1913  return emitError("member one of the resulting struct type must"
1914  "be a scalar or vector of 32 bit integer type");
1915  } else if (!exponentIntTy || exponentIntTy.getWidth() != 32) {
1916  return emitError("member one of the resulting struct type "
1917  "must be a scalar or vector of 32 bit integer type");
1918  }
1919 
1920  // Check that the two member types have the same number of components
1921  if (operandVecTy && exponentVecTy &&
1922  (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
1923  return success();
1924 
1925  if (operandFTy && exponentIntTy)
1926  return success();
1927 
1928  return emitError("member one of the resulting struct type must have the same "
1929  "number of components as the operand type");
1930 }
1931 
1932 //===----------------------------------------------------------------------===//
1933 // spirv.GL.Ldexp
1934 //===----------------------------------------------------------------------===//
1935 
1936 LogicalResult spirv::GLLdexpOp::verify() {
1937  Type significandType = getX().getType();
1938  Type exponentType = getExp().getType();
1939 
1940  if (llvm::isa<FloatType>(significandType) !=
1941  llvm::isa<IntegerType>(exponentType))
1942  return emitOpError("operands must both be scalars or vectors");
1943 
1944  auto getNumElements = [](Type type) -> unsigned {
1945  if (auto vectorType = llvm::dyn_cast<VectorType>(type))
1946  return vectorType.getNumElements();
1947  return 1;
1948  };
1949 
1950  if (getNumElements(significandType) != getNumElements(exponentType))
1951  return emitOpError("operands must have the same number of elements");
1952 
1953  return success();
1954 }
1955 
1956 //===----------------------------------------------------------------------===//
1957 // spirv.ImageDrefGather
1958 //===----------------------------------------------------------------------===//
1959 
1960 LogicalResult spirv::ImageDrefGatherOp::verify() {
1961  VectorType resultType = llvm::cast<VectorType>(getResult().getType());
1962  auto sampledImageType =
1963  llvm::cast<spirv::SampledImageType>(getSampledimage().getType());
1964  auto imageType =
1965  llvm::cast<spirv::ImageType>(sampledImageType.getImageType());
1966 
1967  if (resultType.getNumElements() != 4)
1968  return emitOpError("result type must be a vector of four components");
1969 
1970  Type elementType = resultType.getElementType();
1971  Type sampledElementType = imageType.getElementType();
1972  if (!llvm::isa<NoneType>(sampledElementType) &&
1973  elementType != sampledElementType)
1974  return emitOpError(
1975  "the component type of result must be the same as sampled type of the "
1976  "underlying image type");
1977 
1978  spirv::Dim imageDim = imageType.getDim();
1979  spirv::ImageSamplingInfo imageMS = imageType.getSamplingInfo();
1980 
1981  if (imageDim != spirv::Dim::Dim2D && imageDim != spirv::Dim::Cube &&
1982  imageDim != spirv::Dim::Rect)
1983  return emitOpError(
1984  "the Dim operand of the underlying image type must be 2D, Cube, or "
1985  "Rect");
1986 
1987  if (imageMS != spirv::ImageSamplingInfo::SingleSampled)
1988  return emitOpError("the MS operand of the underlying image type must be 0");
1989 
1990  spirv::ImageOperandsAttr attr = getImageoperandsAttr();
1991  auto operandArguments = getOperandArguments();
1992 
1993  return verifyImageOperands(*this, attr, operandArguments);
1994 }
1995 
1996 //===----------------------------------------------------------------------===//
1997 // spirv.ShiftLeftLogicalOp
1998 //===----------------------------------------------------------------------===//
1999 
2000 LogicalResult spirv::ShiftLeftLogicalOp::verify() {
2001  return verifyShiftOp(*this);
2002 }
2003 
2004 //===----------------------------------------------------------------------===//
2005 // spirv.ShiftRightArithmeticOp
2006 //===----------------------------------------------------------------------===//
2007 
2008 LogicalResult spirv::ShiftRightArithmeticOp::verify() {
2009  return verifyShiftOp(*this);
2010 }
2011 
2012 //===----------------------------------------------------------------------===//
2013 // spirv.ShiftRightLogicalOp
2014 //===----------------------------------------------------------------------===//
2015 
2016 LogicalResult spirv::ShiftRightLogicalOp::verify() {
2017  return verifyShiftOp(*this);
2018 }
2019 
2020 //===----------------------------------------------------------------------===//
2021 // spirv.ImageQuerySize
2022 //===----------------------------------------------------------------------===//
2023 
2024 LogicalResult spirv::ImageQuerySizeOp::verify() {
2025  spirv::ImageType imageType =
2026  llvm::cast<spirv::ImageType>(getImage().getType());
2027  Type resultType = getResult().getType();
2028 
2029  spirv::Dim dim = imageType.getDim();
2030  spirv::ImageSamplingInfo samplingInfo = imageType.getSamplingInfo();
2031  spirv::ImageSamplerUseInfo samplerInfo = imageType.getSamplerUseInfo();
2032  switch (dim) {
2033  case spirv::Dim::Dim1D:
2034  case spirv::Dim::Dim2D:
2035  case spirv::Dim::Dim3D:
2036  case spirv::Dim::Cube:
2037  if (samplingInfo != spirv::ImageSamplingInfo::MultiSampled &&
2038  samplerInfo != spirv::ImageSamplerUseInfo::SamplerUnknown &&
2039  samplerInfo != spirv::ImageSamplerUseInfo::NoSampler)
2040  return emitError(
2041  "if Dim is 1D, 2D, 3D, or Cube, "
2042  "it must also have either an MS of 1 or a Sampled of 0 or 2");
2043  break;
2044  case spirv::Dim::Buffer:
2045  case spirv::Dim::Rect:
2046  break;
2047  default:
2048  return emitError("the Dim operand of the image type must "
2049  "be 1D, 2D, 3D, Buffer, Cube, or Rect");
2050  }
2051 
2052  unsigned componentNumber = 0;
2053  switch (dim) {
2054  case spirv::Dim::Dim1D:
2055  case spirv::Dim::Buffer:
2056  componentNumber = 1;
2057  break;
2058  case spirv::Dim::Dim2D:
2059  case spirv::Dim::Cube:
2060  case spirv::Dim::Rect:
2061  componentNumber = 2;
2062  break;
2063  case spirv::Dim::Dim3D:
2064  componentNumber = 3;
2065  break;
2066  default:
2067  break;
2068  }
2069 
2070  if (imageType.getArrayedInfo() == spirv::ImageArrayedInfo::Arrayed)
2071  componentNumber += 1;
2072 
2073  unsigned resultComponentNumber = 1;
2074  if (auto resultVectorType = llvm::dyn_cast<VectorType>(resultType))
2075  resultComponentNumber = resultVectorType.getNumElements();
2076 
2077  if (componentNumber != resultComponentNumber)
2078  return emitError("expected the result to have ")
2079  << componentNumber << " component(s), but found "
2080  << resultComponentNumber << " component(s)";
2081 
2082  return success();
2083 }
2084 
2085 //===----------------------------------------------------------------------===//
2086 // spirv.VectorTimesScalarOp
2087 //===----------------------------------------------------------------------===//
2088 
2089 LogicalResult spirv::VectorTimesScalarOp::verify() {
2090  if (getVector().getType() != getType())
2091  return emitOpError("vector operand and result type mismatch");
2092  auto scalarType = llvm::cast<VectorType>(getType()).getElementType();
2093  if (getScalar().getType() != scalarType)
2094  return emitOpError("scalar operand and result element type match");
2095  return success();
2096 }
static std::string bindingName()
Returns the string name of the Binding decoration.
static std::string descriptorSetName()
Returns the string name of the DescriptorSet decoration.
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser, OperationState &result)
Definition: SPIRVOps.cpp:296
static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, Type opType)
Definition: SPIRVOps.cpp:586
static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, OperationState &result)
Definition: SPIRVOps.cpp:121
static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op)
Definition: SPIRVOps.cpp:282
static LogicalResult verifyShiftOp(Operation *op)
Definition: SPIRVOps.cpp:328
static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op, Value ptr, Value val)
Definition: SPIRVOps.cpp:198
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:215
static void printOneResultOp(Operation *op, OpAsmPrinter &p)
Definition: SPIRVOps.cpp:150
static LogicalResult verifyImageOperands(Op imageOp, spirv::ImageOperandsAttr attr, Operation::operand_range operands)
Definition: SPIRVOps.cpp:170
static void printArithmeticExtendedBinaryOp(Operation *op, OpAsmPrinter &printer)
Definition: SPIRVOps.cpp:320
const float * table
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseOptionalSymbolName(StringAttr &result)=0
Parse an optional @-identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
ParseResult addTypesToList(ArrayRef< Type > types, SmallVectorImpl< Type > &result)
Add the specified types to the end of the specified type list and return success.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType & getOperations()
Definition: Block.h:137
Operation & front()
Definition: Block.h:153
iterator begin()
Definition: Block.h:143
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:240
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:268
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:316
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:294
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:120
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:111
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:140
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:302
MLIRContext * getContext() const
Definition: Builders.h:55
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:306
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition: Builders.h:106
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual Operation * parseGenericOperation(Block *insertBlock, Block::iterator insertPt)=0
Parse an operation in its generic form.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printGenericOp(Operation *op, bool printOpName=true)=0
Print the entire operation with the default generic assembly form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:356
This class helps build Operations.
Definition: Builders.h:215
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:444
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:470
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:826
A trait to mark ops that can be enclosed/wrapped in a SpecConstantOperation op.
Definition: SPIRVOpTraits.h:33
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
type_range getType() const
Definition: ValueRange.cpp:30
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:745
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:545
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
void push_back(Block *block)
Definition: Region.h:61
bool empty()
Definition: Region.h:60
Block & back()
Definition: Region.h:64
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:123
Type front()
Return first type in the range.
Definition: TypeRange.h:148
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:33
static WalkResult advance()
Definition: Visitors.h:51
ImageArrayedInfo getArrayedInfo() const
Definition: SPIRVTypes.cpp:351
ImageSamplerUseInfo getSamplerUseInfo() const
Definition: SPIRVTypes.cpp:359
ImageSamplingInfo getSamplingInfo() const
Definition: SPIRVTypes.cpp:355
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:406
SPIR-V struct type.
Definition: SPIRVTypes.h:293
unsigned getNumElements() const
Type getElementType(unsigned) const
An attribute that specifies the SPIR-V (version, capabilities, extensions) triple.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:136
ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the argument at 'index'.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
ParseResult parseFunctionSignature(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
constexpr char kFnNameAttrName[]
constexpr char kSpecIdAttrName[]
LogicalResult verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics)
Definition: SPIRVOps.cpp:70
ParseResult parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next string attribute in parser as an enumerant of the given EnumClass.
void printVariableDecorations(Operation *op, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs)
Definition: SPIRVOps.cpp:94
AddressingModel getAddressingModel(TargetEnvAttr targetAttr, bool use64bitAddress)
Returns addressing model selected based on target environment.
FailureOr< ExecutionModel > getExecutionModel(TargetEnvAttr targetAttr)
Returns execution model selected based on target environment.
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
LogicalResult extractValueFromConstOp(Operation *op, int32_t &value)
Definition: SPIRVOps.cpp:50
ParseResult parseVariableDecorations(OpAsmParser &parser, OperationState &state)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:426
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.