MLIR  21.0.0git
SPIRVOps.cpp
Go to the documentation of this file.
1 //===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the operations in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
15 #include "SPIRVOpUtils.h"
16 #include "SPIRVParsingUtils.h"
17 
24 #include "mlir/IR/Builders.h"
25 #include "mlir/IR/BuiltinTypes.h"
26 #include "mlir/IR/Matchers.h"
27 #include "mlir/IR/OpDefinition.h"
29 #include "mlir/IR/Operation.h"
30 #include "mlir/IR/TypeUtilities.h"
32 #include "llvm/ADT/APFloat.h"
33 #include "llvm/ADT/APInt.h"
34 #include "llvm/ADT/ArrayRef.h"
35 #include "llvm/ADT/STLExtras.h"
36 #include "llvm/ADT/StringExtras.h"
37 #include "llvm/ADT/TypeSwitch.h"
38 #include <cassert>
39 #include <numeric>
40 #include <optional>
41 #include <type_traits>
42 
43 using namespace mlir;
44 using namespace mlir::spirv::AttrNames;
45 
46 //===----------------------------------------------------------------------===//
47 // Common utility functions
48 //===----------------------------------------------------------------------===//
49 
50 LogicalResult spirv::extractValueFromConstOp(Operation *op, int32_t &value) {
51  auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
52  if (!constOp) {
53  return failure();
54  }
55  auto valueAttr = constOp.getValue();
56  auto integerValueAttr = llvm::dyn_cast<IntegerAttr>(valueAttr);
57  if (!integerValueAttr) {
58  return failure();
59  }
60 
61  if (integerValueAttr.getType().isSignlessInteger())
62  value = integerValueAttr.getInt();
63  else
64  value = integerValueAttr.getSInt();
65 
66  return success();
67 }
68 
// Verifies the ordering bits of a SPIR-V MemorySemantics mask: at most one of
// Acquire, Release, AcquireRelease and SequentiallyConsistent may be set.
// Emits an error on `op` and returns failure when more than one is set.
// NOTE(review): original line 70 (the function name and the leading
// parameter(s), which must declare the `op` used below) is missing from this
// extract. spirv.ControlBarrier's verifier calls it as
// verifyMemorySemantics(getOperation(), getMemorySemantics()).
69 LogicalResult
71  spirv::MemorySemantics memorySemantics) {
72  // According to the SPIR-V specification:
73  // "Despite being a mask and allowing multiple bits to be combined, it is
74  // invalid for more than one of these four bits to be set: Acquire, Release,
75  // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and
76  // Release semantics is done by setting the AcquireRelease bit, not by setting
77  // two bits."
78  auto atMostOneInSet = spirv::MemorySemantics::Acquire |
79  spirv::MemorySemantics::Release |
80  spirv::MemorySemantics::AcquireRelease |
81  spirv::MemorySemantics::SequentiallyConsistent;
82 
// Mask off everything but the four ordering bits, then count what remains;
// other bits of the mask do not participate in this check.
83  auto bitCount =
84  llvm::popcount(static_cast<uint32_t>(memorySemantics & atMostOneInSet));
85  if (bitCount > 1) {
// NOTE(review): the adjacent literals below concatenate to
// "...`Release`,`AcquireRelease`..." with no space after the comma.
86  return op->emitError(
87  "expected at most one of these four memory constraints "
88  "to be set: `Acquire`, `Release`,"
89  "`AcquireRelease` or `SequentiallyConsistent`");
90  }
91  return success();
92 }
93 
// Prints the optional variable decorations of `op`:
//   ` bind(<set>, <binding>)` when BOTH the descriptor-set and binding
//   attributes are present, and ` <builtin-attr-name>("<value>")` when the
//   BuiltIn attribute is present.
// The names of the attributes printed this way are appended to `elidedAttrs`,
// and every remaining attribute is then printed as an attribute dictionary.
// NOTE(review): the first line of the signature (original line 94, which
// declares the `op` and `printer` used below) is missing from this extract.
95  SmallVectorImpl<StringRef> &elidedAttrs) {
96  // Print optional descriptor binding
// The attribute names are the snake_case spellings of the decoration names.
97  auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
98  stringifyDecoration(spirv::Decoration::DescriptorSet));
99  auto bindingName = llvm::convertToSnakeFromCamelCase(
100  stringifyDecoration(spirv::Decoration::Binding));
101  auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
102  auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
// Only elide the pair when both are present; a lone descriptor set or binding
// is left for the generic attribute dictionary printed at the end.
// NOTE(review): convertToSnakeFromCamelCase returns a std::string, so the
// StringRefs pushed into `elidedAttrs` here and below point into locals of
// this function and dangle once it returns; callers must not read them later.
103  if (descriptorSet && binding) {
104  elidedAttrs.push_back(descriptorSetName);
105  elidedAttrs.push_back(bindingName);
106  printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
107  << ")";
108  }
109 
110  // Print BuiltIn attribute if present
111  auto builtInName = llvm::convertToSnakeFromCamelCase(
112  stringifyDecoration(spirv::Decoration::BuiltIn));
113  if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
114  printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
115  elidedAttrs.push_back(builtInName);
116  }
117 
118  printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
119 }
120 
// Parses a single-result op whose operands and result normally share a type.
// Two forms are accepted:
//   `(` operands `)` attr-dict `:` function-type
//       -- the generic fallback printed by printOneResultOp when types differ;
//   operands attr-dict `:` type
//       -- every operand and the result have `type`.
// NOTE(review): original line 123, which declares the operand list `ops` used
// below, is missing from this extract.
121 static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
122  OperationState &result) {
124  Type type;
125  // If the operand list is in-between parentheses, then we have a generic form.
126  // (see the fallback in `printOneResultOp`).
// `loc` is taken before the operand list, so diagnostics below (including
// "expected function type") point at the start of the operands.
127  SMLoc loc = parser.getCurrentLocation();
128  if (!parser.parseOptionalLParen()) {
129  if (parser.parseOperandList(ops) || parser.parseRParen() ||
130  parser.parseOptionalAttrDict(result.attributes) ||
131  parser.parseColon() || parser.parseType(type))
132  return failure();
133  auto fnType = llvm::dyn_cast<FunctionType>(type);
134  if (!fnType) {
135  parser.emitError(loc, "expected function type");
136  return failure();
137  }
138  if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands))
139  return failure();
140  result.addTypes(fnType.getResults());
141  return success();
142  }
// Short form: the single trailing type applies to all operands and the result.
143  return failure(parser.parseOperandList(ops) ||
144  parser.parseOptionalAttrDict(result.attributes) ||
145  parser.parseColonType(type) ||
146  parser.resolveOperands(ops, type, result.operands) ||
147  parser.addTypeToList(type, result.types));
148 }
149 
// Prints an op that has exactly one result. When every operand type equals
// the result type, the short form `operands : type` is printed; otherwise the
// generic form (without the op name) is printed so that no type information
// is lost. parseOneResultSameOperandTypeOp accepts both forms.
// NOTE(review): the signature line (original line 150, which declares `op`
// and `p`) and original line 164 are missing from this extract. Line 164
// presumably prints the optional attribute dictionary, since the parser reads
// an attr-dict right after the operand list — confirm against upstream.
151  assert(op->getNumResults() == 1 && "op should have one result");
152 
153  // If not all the operand and result types are the same, just use the
154  // generic assembly form to avoid omitting information in printing.
155  auto resultType = op->getResult(0).getType();
156  if (llvm::any_of(op->getOperandTypes(),
157  [&](Type type) { return type != resultType; })) {
158  p.printGenericOp(op, /*printOpName=*/false);
159  return;
160  }
161 
162  p << ' ';
163  p.printOperands(op->getOperands());
165  // Now we can output only one type for all operands and the result.
166  p << " : " << resultType;
167 }
168 
169 template <typename BlockReadWriteOpTy>
170 static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
171  Value ptr, Value val) {
172  auto valType = val.getType();
173  if (auto valVecTy = llvm::dyn_cast<VectorType>(valType))
174  valType = valVecTy.getElementType();
175 
176  if (valType !=
177  llvm::cast<spirv::PointerType>(ptr.getType()).getPointeeType()) {
178  return op.emitOpError("mismatch in result type and pointer type");
179  }
180  return success();
181 }
182 
183 /// Walks the given type hierarchy with the given indices, potentially down
184 /// to component granularity, to select an element type. Returns null type and
185 /// emits errors through `emitErrorFn` on failure.
// NOTE(review): original line 187 (the function name and the `type` /
// `indices` parameters) is missing from this extract. The Attribute-based
// overload calls it as getElementType(type, indexVals, emitErrorFn), where
// indexVals is a SmallVector<int32_t>.
186 static Type
188  function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
// NOTE(review): this message names spirv.CompositeExtract even though the
// spirv.CompositeInsert verifier also reaches this helper.
189  if (indices.empty()) {
190  emitErrorFn("expected at least one index for spirv.CompositeExtract");
191  return nullptr;
192  }
193 
// Descend one composite level per index; `type` is rebound each iteration.
194  for (auto index : indices) {
195  if (auto cType = llvm::dyn_cast<spirv::CompositeType>(type)) {
// Bounds (including negativity) are only checked when the element count is
// known at compile time; otherwise the index is passed through unchecked.
196  if (cType.hasCompileTimeKnownNumElements() &&
197  (index < 0 ||
198  static_cast<uint64_t>(index) >= cType.getNumElements())) {
199  emitErrorFn("index ") << index << " out of bounds for " << type;
200  return nullptr;
201  }
202  type = cType.getElementType(index);
203  } else {
204  emitErrorFn("cannot extract from non-composite type ")
205  << type << " with index " << index;
206  return nullptr;
207  }
208  }
209  return type;
210 }
211 
// Attribute-based overload: checks that `indices` is a non-empty ArrayAttr of
// IntegerAttrs, converts them to int32 values, and defers the type walk to
// the ArrayRef<int32_t> overload. Returns a null Type and reports through
// `emitErrorFn` on failure.
// NOTE(review): original line 213 (the function name and the `type` /
// `indices` parameters) is missing from this extract.
212 static Type
214  function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
215  auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(indices);
216  if (!indicesArrayAttr) {
217  emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
218  return nullptr;
219  }
220  if (indicesArrayAttr.empty()) {
221  emitErrorFn("expected at least one index for spirv.CompositeExtract");
222  return nullptr;
223  }
224 
225  SmallVector<int32_t, 2> indexVals;
226  for (auto indexAttr : indicesArrayAttr) {
227  auto indexIntAttr = llvm::dyn_cast<IntegerAttr>(indexAttr);
228  if (!indexIntAttr) {
229  emitErrorFn("expected an 32-bit integer for index, but found '")
230  << indexAttr << "'";
231  return nullptr;
232  }
// NOTE(review): the bit width is not checked here; getInt() yields int64_t,
// which is narrowed to int32_t by push_back, and getInt() asserts on
// non-signless integer types.
233  indexVals.push_back(indexIntAttr.getInt());
234  }
235  return getElementType(type, indexVals, emitErrorFn);
236 }
237 
238 static Type getElementType(Type type, Attribute indices, Location loc) {
239  auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
240  return ::mlir::emitError(loc, err);
241  };
242  return getElementType(type, indices, errorFn);
243 }
244 
245 static Type getElementType(Type type, Attribute indices, OpAsmParser &parser,
246  SMLoc loc) {
247  auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
248  return parser.emitError(loc, err);
249  };
250  return getElementType(type, indices, errorFn);
251 }
252 
253 template <typename ExtendedBinaryOp>
254 static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op) {
255  auto resultType = llvm::cast<spirv::StructType>(op.getType());
256  if (resultType.getNumElements() != 2)
257  return op.emitOpError("expected result struct type containing two members");
258 
259  if (!llvm::all_equal({op.getOperand1().getType(), op.getOperand2().getType(),
260  resultType.getElementType(0),
261  resultType.getElementType(1)}))
262  return op.emitOpError(
263  "expected all operand types and struct member types are the same");
264 
265  return success();
266 }
267 
// Parses `attr-dict operands : struct-type` for extended binary arithmetic
// ops. The trailing type must be a spirv.struct with exactly two members.
// Both operands are resolved against the struct's first member type, and the
// struct itself becomes the single result type.
// NOTE(review): original line 270, which declares the `operands` list used
// below, is missing from this extract.
268 static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser,
269  OperationState &result) {
271  if (parser.parseOptionalAttrDict(result.attributes) ||
272  parser.parseOperandList(operands) || parser.parseColon())
273  return failure();
274 
275  Type resultType;
// `loc` points at the result type; it is also used for operand-resolution
// errors below.
276  SMLoc loc = parser.getCurrentLocation();
277  if (parser.parseType(resultType))
278  return failure();
279 
280  auto structType = llvm::dyn_cast<spirv::StructType>(resultType);
281  if (!structType || structType.getNumElements() != 2)
282  return parser.emitError(loc, "expected spirv.struct type with two members");
283 
// Exactly two operand types are supplied, so any other operand count fails
// in resolveOperands.
284  SmallVector<Type, 2> operandTypes(2, structType.getElementType(0));
285  if (parser.resolveOperands(operands, operandTypes, loc, result.operands))
286  return failure();
287 
288  result.addTypes(resultType);
289  return success();
290 }
291 
// Prints `attr-dict operands : result-type`, mirroring the order read by
// parseArithmeticExtendedBinaryOp (attributes first, then operands).
// NOTE(review): the first line of the signature (original line 292, which
// declares the `op` used below) is missing from this extract.
293  OpAsmPrinter &printer) {
294  printer << ' ';
295  printer.printOptionalAttrDict(op->getAttrs());
296  printer.printOperands(op->getOperands());
297  printer << " : " << op->getResultTypes().front();
298 }
299 
300 static LogicalResult verifyShiftOp(Operation *op) {
301  if (op->getOperand(0).getType() != op->getResult(0).getType()) {
302  return op->emitError("expected the same type for the first operand and "
303  "result, but provided ")
304  << op->getOperand(0).getType() << " and "
305  << op->getResult(0).getType();
306  }
307  return success();
308 }
309 
310 //===----------------------------------------------------------------------===//
311 // spirv.mlir.addressof
312 //===----------------------------------------------------------------------===//
313 
314 void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state,
315  spirv::GlobalVariableOp var) {
316  build(builder, state, var.getType(), SymbolRefAttr::get(var));
317 }
318 
319 LogicalResult spirv::AddressOfOp::verify() {
320  auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
321  SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(),
322  getVariableAttr()));
323  if (!varOp) {
324  return emitOpError("expected spirv.GlobalVariable symbol");
325  }
326  if (getPointer().getType() != varOp.getType()) {
327  return emitOpError(
328  "result type mismatch with the referenced global variable's type");
329  }
330  return success();
331 }
332 
333 //===----------------------------------------------------------------------===//
334 // spirv.CompositeConstruct
335 //===----------------------------------------------------------------------===//
336 
// Verifies spirv.CompositeConstruct: the constituents must be consistent with
// the composite result type according to the four cases listed below.
337 LogicalResult spirv::CompositeConstructOp::verify() {
338  operand_range constituents = this->getConstituents();
339 
340  // There are 4 cases with varying verification rules:
341  // 1. Cooperative Matrices (1 constituent)
342  // 2. Structs (1 constituent for each member)
343  // 3. Arrays (1 constituent for each array element)
344  // 4. Vectors (1 constituent (sub-)element for each vector element)
345 
// NOTE(review): original lines 347-348 are missing from this extract. From the
// visible tail they are presumably an llvm::TypeSwitch over the result type
// whose .Case(s) match cooperative-matrix types and yield their element type;
// any other type yields null (the .Default below) — confirm against upstream.
346  auto coopElementType =
349  [](auto coopType) { return coopType.getElementType(); })
350  .Default([](Type) { return nullptr; });
351 
352  // Case 1. -- matrices.
// A cooperative-matrix result is built from exactly one value whose type is
// the matrix element type.
353  if (coopElementType) {
354  if (constituents.size() != 1)
355  return emitOpError("has incorrect number of operands: expected ")
356  << "1, but provided " << constituents.size();
357  if (coopElementType != constituents.front().getType())
358  return emitOpError("operand type mismatch: expected operand type ")
359  << coopElementType << ", but provided "
360  << constituents.front().getType();
361  return success();
362  }
363 
364  // Case 2./3./4. -- number of constituents matches the number of elements.
365  auto cType = llvm::cast<spirv::CompositeType>(getType());
366  if (constituents.size() == cType.getNumElements()) {
367  for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
368  if (constituents[index].getType() != cType.getElementType(index)) {
369  return emitOpError("operand type mismatch: expected operand type ")
370  << cType.getElementType(index) << ", but provided "
371  << constituents[index].getType();
372  }
373  }
374  return success();
375  }
376 
377  // Case 4. -- check that all constituents add up to the expected vector type.
// NOTE(review): this path is also taken when there are MORE constituents than
// elements, although the diagnostic below only mentions "less".
378  auto resultType = llvm::dyn_cast<VectorType>(cType);
379  if (!resultType)
380  return emitOpError(
381  "expected to return a vector or cooperative matrix when the number of "
382  "constituents is less than what the result needs");
383 
// Each constituent is a scalar (contributing 1 element) or a vector
// (contributing its element count); all must share the result element type.
384  SmallVector<unsigned> sizes;
385  for (Value component : constituents) {
386  if (!llvm::isa<VectorType>(component.getType()) &&
387  !component.getType().isIntOrFloat())
388  return emitOpError("operand type mismatch: expected operand to have "
389  "a scalar or vector type, but provided ")
390  << component.getType();
391 
392  Type elementType = component.getType();
393  if (auto vectorType = llvm::dyn_cast<VectorType>(component.getType())) {
394  sizes.push_back(vectorType.getNumElements());
395  elementType = vectorType.getElementType();
396  } else {
397  sizes.push_back(1);
398  }
399 
400  if (elementType != resultType.getElementType())
401  return emitOpError("operand element type mismatch: expected to be ")
402  << resultType.getElementType() << ", but provided " << elementType;
403  }
404  unsigned totalCount = std::accumulate(sizes.begin(), sizes.end(), 0);
405  if (totalCount != cType.getNumElements())
406  return emitOpError("has incorrect number of operands: expected ")
407  << cType.getNumElements() << ", but provided " << totalCount;
408  return success();
409 }
410 
411 //===----------------------------------------------------------------------===//
412 // spirv.CompositeExtractOp
413 //===----------------------------------------------------------------------===//
414 
415 void spirv::CompositeExtractOp::build(OpBuilder &builder, OperationState &state,
416  Value composite,
417  ArrayRef<int32_t> indices) {
418  auto indexAttr = builder.getI32ArrayAttr(indices);
419  auto elementType =
420  getElementType(composite.getType(), indexAttr, state.location);
421  if (!elementType) {
422  return;
423  }
424  build(builder, state, elementType, composite, indexAttr);
425 }
426 
// Parses `%composite[indices] : composite-type`. The result type is not
// spelled out; it is derived by walking composite-type with the indices, and
// errors are reported at the indices attribute.
// NOTE(review): the first line of the signature (original line 427, which
// declares the `parser` used below) is missing from this extract.
428  OperationState &result) {
429  OpAsmParser::UnresolvedOperand compositeInfo;
430  Attribute indicesAttr;
431  StringRef indicesAttrName =
432  spirv::CompositeExtractOp::getIndicesAttrName(result.name);
433  Type compositeType;
434  SMLoc attrLocation;
435 
436  if (parser.parseOperand(compositeInfo) ||
437  parser.getCurrentLocation(&attrLocation) ||
438  parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
439  parser.parseColonType(compositeType) ||
440  parser.resolveOperand(compositeInfo, compositeType, result.operands)) {
441  return failure();
442  }
443 
444  Type resultType =
445  getElementType(compositeType, indicesAttr, parser, attrLocation);
446  if (!resultType) {
447  return failure();
448  }
449  result.addTypes(resultType);
450  return success();
451 }
452 
// Prints ` %composite[indices] : composite-type`, the form read by
// CompositeExtractOp::parse.
// NOTE(review): the signature line (original line 453, which declares
// `printer`) is missing from this extract.
454  printer << ' ' << getComposite() << getIndices() << " : "
455  << getComposite().getType();
456 }
457 
458 LogicalResult spirv::CompositeExtractOp::verify() {
459  auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(getIndices());
460  auto resultType =
461  getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
462  if (!resultType)
463  return failure();
464 
465  if (resultType != getType()) {
466  return emitOpError("invalid result type: expected ")
467  << resultType << " but provided " << getType();
468  }
469 
470  return success();
471 }
472 
473 //===----------------------------------------------------------------------===//
474 // spirv.CompositeInsert
475 //===----------------------------------------------------------------------===//
476 
477 void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state,
478  Value object, Value composite,
479  ArrayRef<int32_t> indices) {
480  auto indexAttr = builder.getI32ArrayAttr(indices);
481  build(builder, state, composite.getType(), object, composite, indexAttr);
482 }
483 
// Parses `%object, %composite[indices] : object-type into composite-type`.
// Exactly two operands are required; the result type is composite-type.
// NOTE(review): original line 486, which declares the `operands` list used
// below, is missing from this extract.
484 ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
485  OperationState &result) {
487  Type objectType, compositeType;
488  Attribute indicesAttr;
489  StringRef indicesAttrName =
490  spirv::CompositeInsertOp::getIndicesAttrName(result.name);
491  auto loc = parser.getCurrentLocation();
492 
493  return failure(
494  parser.parseOperandList(operands, 2) ||
495  parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
496  parser.parseColonType(objectType) ||
497  parser.parseKeywordType("into", compositeType) ||
498  parser.resolveOperands(operands, {objectType, compositeType}, loc,
499  result.operands) ||
500  parser.addTypesToList(compositeType, result.types));
501 }
502 
503 LogicalResult spirv::CompositeInsertOp::verify() {
504  auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(getIndices());
505  auto objectType =
506  getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
507  if (!objectType)
508  return failure();
509 
510  if (objectType != getObject().getType()) {
511  return emitOpError("object operand type should be ")
512  << objectType << ", but found " << getObject().getType();
513  }
514 
515  if (getComposite().getType() != getType()) {
516  return emitOpError("result type should be the same as "
517  "the composite type, but found ")
518  << getComposite().getType() << " vs " << getType();
519  }
520 
521  return success();
522 }
523 
// Prints ` %object, %composite[indices] : object-type into composite-type`,
// the form read by CompositeInsertOp::parse.
// NOTE(review): the signature line (original line 524, which declares
// `printer`) is missing from this extract.
525  printer << " " << getObject() << ", " << getComposite() << getIndices()
526  << " : " << getObject().getType() << " into "
527  << getComposite().getType();
528 }
529 
530 //===----------------------------------------------------------------------===//
531 // spirv.Constant
532 //===----------------------------------------------------------------------===//
533 
534 ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
535  OperationState &result) {
536  Attribute value;
537  StringRef valueAttrName = spirv::ConstantOp::getValueAttrName(result.name);
538  if (parser.parseAttribute(value, valueAttrName, result.attributes))
539  return failure();
540 
541  Type type = NoneType::get(parser.getContext());
542  if (auto typedAttr = llvm::dyn_cast<TypedAttr>(value))
543  type = typedAttr.getType();
544  if (llvm::isa<NoneType, TensorType>(type)) {
545  if (parser.parseColonType(type))
546  return failure();
547  }
548 
549  return parser.addTypeToList(type, result.types);
550 }
551 
552 void spirv::ConstantOp::print(OpAsmPrinter &printer) {
553  printer << ' ' << getValue();
554  if (llvm::isa<spirv::ArrayType>(getType()))
555  printer << " : " << getType();
556 }
557 
558 static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
559  Type opType) {
560  if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
561  auto valueType = llvm::cast<TypedAttr>(value).getType();
562  if (valueType != opType)
563  return op.emitOpError("result type (")
564  << opType << ") does not match value type (" << valueType << ")";
565  return success();
566  }
567  if (llvm::isa<DenseIntOrFPElementsAttr, SparseElementsAttr>(value)) {
568  auto valueType = llvm::cast<TypedAttr>(value).getType();
569  if (valueType == opType)
570  return success();
571  auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
572  auto shapedType = llvm::dyn_cast<ShapedType>(valueType);
573  if (!arrayType)
574  return op.emitOpError("result or element type (")
575  << opType << ") does not match value type (" << valueType
576  << "), must be the same or spirv.array";
577 
578  int numElements = arrayType.getNumElements();
579  auto opElemType = arrayType.getElementType();
580  while (auto t = llvm::dyn_cast<spirv::ArrayType>(opElemType)) {
581  numElements *= t.getNumElements();
582  opElemType = t.getElementType();
583  }
584  if (!opElemType.isIntOrFloat())
585  return op.emitOpError("only support nested array result type");
586 
587  auto valueElemType = shapedType.getElementType();
588  if (valueElemType != opElemType) {
589  return op.emitOpError("result element type (")
590  << opElemType << ") does not match value element type ("
591  << valueElemType << ")";
592  }
593 
594  if (numElements != shapedType.getNumElements()) {
595  return op.emitOpError("result number of elements (")
596  << numElements << ") does not match value number of elements ("
597  << shapedType.getNumElements() << ")";
598  }
599  return success();
600  }
601  if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(value)) {
602  auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
603  if (!arrayType)
604  return op.emitOpError(
605  "must have spirv.array result type for array value");
606  Type elemType = arrayType.getElementType();
607  for (Attribute element : arrayAttr.getValue()) {
608  // Verify array elements recursively.
609  if (failed(verifyConstantType(op, element, elemType)))
610  return failure();
611  }
612  return success();
613  }
614  return op.emitOpError("cannot have attribute: ") << value;
615 }
616 
617 LogicalResult spirv::ConstantOp::verify() {
618  // ODS already generates checks to make sure the result type is valid. We just
619  // need to additionally check that the value's attribute type is consistent
620  // with the result type.
621  return verifyConstantType(*this, getValueAttr(), getType());
622 }
623 
624 bool spirv::ConstantOp::isBuildableWith(Type type) {
625  // Must be valid SPIR-V type first.
626  if (!llvm::isa<spirv::SPIRVType>(type))
627  return false;
628 
629  if (isa<SPIRVDialect>(type.getDialect())) {
630  // TODO: support constant struct
631  return llvm::isa<spirv::ArrayType>(type);
632  }
633 
634  return true;
635 }
636 
637 spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
638  OpBuilder &builder) {
639  if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
640  unsigned width = intType.getWidth();
641  if (width == 1)
642  return builder.create<spirv::ConstantOp>(loc, type,
643  builder.getBoolAttr(false));
644  return builder.create<spirv::ConstantOp>(
645  loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
646  }
647  if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
648  return builder.create<spirv::ConstantOp>(
649  loc, type, builder.getFloatAttr(floatType, 0.0));
650  }
651  if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
652  Type elemType = vectorType.getElementType();
653  if (llvm::isa<IntegerType>(elemType)) {
654  return builder.create<spirv::ConstantOp>(
655  loc, type,
656  DenseElementsAttr::get(vectorType,
657  IntegerAttr::get(elemType, 0).getValue()));
658  }
659  if (llvm::isa<FloatType>(elemType)) {
660  return builder.create<spirv::ConstantOp>(
661  loc, type,
662  DenseFPElementsAttr::get(vectorType,
663  FloatAttr::get(elemType, 0.0).getValue()));
664  }
665  }
666 
667  llvm_unreachable("unimplemented types for ConstantOp::getZero()");
668 }
669 
670 spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
671  OpBuilder &builder) {
672  if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
673  unsigned width = intType.getWidth();
674  if (width == 1)
675  return builder.create<spirv::ConstantOp>(loc, type,
676  builder.getBoolAttr(true));
677  return builder.create<spirv::ConstantOp>(
678  loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
679  }
680  if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
681  return builder.create<spirv::ConstantOp>(
682  loc, type, builder.getFloatAttr(floatType, 1.0));
683  }
684  if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
685  Type elemType = vectorType.getElementType();
686  if (llvm::isa<IntegerType>(elemType)) {
687  return builder.create<spirv::ConstantOp>(
688  loc, type,
689  DenseElementsAttr::get(vectorType,
690  IntegerAttr::get(elemType, 1).getValue()));
691  }
692  if (llvm::isa<FloatType>(elemType)) {
693  return builder.create<spirv::ConstantOp>(
694  loc, type,
695  DenseFPElementsAttr::get(vectorType,
696  FloatAttr::get(elemType, 1.0).getValue()));
697  }
698  }
699 
700  llvm_unreachable("unimplemented types for ConstantOp::getOne()");
701 }
702 
703 void mlir::spirv::ConstantOp::getAsmResultNames(
704  llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
705  Type type = getType();
706 
707  SmallString<32> specialNameBuffer;
708  llvm::raw_svector_ostream specialName(specialNameBuffer);
709  specialName << "cst";
710 
711  IntegerType intTy = llvm::dyn_cast<IntegerType>(type);
712 
713  if (IntegerAttr intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
714  if (intTy && intTy.getWidth() == 1) {
715  return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
716  }
717 
718  if (intTy.isSignless()) {
719  specialName << intCst.getInt();
720  } else if (intTy.isUnsigned()) {
721  specialName << intCst.getUInt();
722  } else {
723  specialName << intCst.getSInt();
724  }
725  }
726 
727  if (intTy || llvm::isa<FloatType>(type)) {
728  specialName << '_' << type;
729  }
730 
731  if (auto vecType = llvm::dyn_cast<VectorType>(type)) {
732  specialName << "_vec_";
733  specialName << vecType.getDimSize(0);
734 
735  Type elementType = vecType.getElementType();
736 
737  if (llvm::isa<IntegerType>(elementType) ||
738  llvm::isa<FloatType>(elementType)) {
739  specialName << "x" << elementType;
740  }
741  }
742 
743  setNameFn(getResult(), specialName.str());
744 }
745 
746 void mlir::spirv::AddressOfOp::getAsmResultNames(
747  llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
748  SmallString<32> specialNameBuffer;
749  llvm::raw_svector_ostream specialName(specialNameBuffer);
750  specialName << getVariable() << "_addr";
751  setNameFn(getResult(), specialName.str());
752 }
753 
754 //===----------------------------------------------------------------------===//
755 // spirv.ControlBarrierOp
756 //===----------------------------------------------------------------------===//
757 
758 LogicalResult spirv::ControlBarrierOp::verify() {
759  return verifyMemorySemantics(getOperation(), getMemorySemantics());
760 }
761 
762 //===----------------------------------------------------------------------===//
763 // spirv.EntryPoint
764 //===----------------------------------------------------------------------===//
765 
766 void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state,
767  spirv::ExecutionModel executionModel,
768  spirv::FuncOp function,
769  ArrayRef<Attribute> interfaceVars) {
770  build(builder, state,
771  spirv::ExecutionModelAttr::get(builder.getContext(), executionModel),
772  SymbolRefAttr::get(function), builder.getArrayAttr(interfaceVars));
773 }
774 
// Parses `"<ExecutionModel>" @fn (, @var)*`. The referenced function is
// stored under kFnNameAttrName, and the interface variables (possibly none)
// under the op's interface attribute as an ArrayAttr of symbol references.
// NOTE(review): original lines 778 (presumably the declaration of `fn` used
// below) and 782 are missing from this extract.
775 ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
776  OperationState &result) {
777  spirv::ExecutionModel execModel;
// NOTE(review): `idTypes` is not used in any of the visible lines.
779  SmallVector<Type, 0> idTypes;
780  SmallVector<Attribute, 4> interfaceVars;
781 
783  if (parseEnumStrAttr<spirv::ExecutionModelAttr>(execModel, parser, result) ||
784  parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) {
785  return failure();
786  }
787 
788  if (!parser.parseOptionalComma()) {
789  // Parse the interface variables
790  if (parser.parseCommaSeparatedList([&]() -> ParseResult {
791  // The name of the interface variable attribute isn't important
792  FlatSymbolRefAttr var;
793  NamedAttrList attrs;
794  if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
795  return failure();
796  interfaceVars.push_back(var);
797  return success();
798  }))
799  return failure();
800  }
// The interface attribute is always added, as an empty array when no
// variables were listed.
801  result.addAttribute(spirv::EntryPointOp::getInterfaceAttrName(result.name),
802  parser.getBuilder().getArrayAttr(interfaceVars));
803  return success();
804 }
805 
// Prints ` "<ExecutionModel>" @fn` followed by `, @var, ...` when the
// interface list is non-empty — the form read by EntryPointOp::parse.
// NOTE(review): the signature line (original line 806, which declares
// `printer`) is missing from this extract.
807  printer << " \"" << stringifyExecutionModel(getExecutionModel()) << "\" ";
808  printer.printSymbolName(getFn());
809  auto interfaceVars = getInterface().getValue();
810  if (!interfaceVars.empty()) {
811  printer << ", ";
812  llvm::interleaveComma(interfaceVars, printer);
813  }
814 }
815 
816 LogicalResult spirv::EntryPointOp::verify() {
817  // Checks for fn and interface symbol reference are done in spirv::ModuleOp
818  // verification.
819  return success();
820 }
821 
822 //===----------------------------------------------------------------------===//
823 // spirv.ExecutionMode
824 //===----------------------------------------------------------------------===//
825 
826 void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state,
827  spirv::FuncOp function,
828  spirv::ExecutionMode executionMode,
829  ArrayRef<int32_t> params) {
830  build(builder, state, SymbolRefAttr::get(function),
831  spirv::ExecutionModeAttr::get(builder.getContext(), executionMode),
832  builder.getI32ArrayAttr(params));
833 }
834 
// Parses `@fn "<ExecutionMode>" (, <i32>)*`. The trailing literals are
// collected into the op's values attribute as an i32 ArrayAttr, which is
// always added (empty when there are no literals).
// NOTE(review): original line 844, which declares the `values` container used
// below, is missing from this extract.
835 ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
836  OperationState &result) {
837  spirv::ExecutionMode execMode;
838  Attribute fn;
839  if (parser.parseAttribute(fn, kFnNameAttrName, result.attributes) ||
840  parseEnumStrAttr<spirv::ExecutionModeAttr>(execMode, parser, result)) {
841  return failure();
842  }
843 
845  Type i32Type = parser.getBuilder().getIntegerType(32);
846  while (!parser.parseOptionalComma()) {
// Each literal is parsed as an i32-typed attribute into a throwaway list; only
// its integer value is kept.
847  NamedAttrList attr;
848  Attribute value;
849  if (parser.parseAttribute(value, i32Type, "value", attr)) {
850  return failure();
851  }
852  values.push_back(llvm::cast<IntegerAttr>(value).getInt());
853  }
854  StringRef valuesAttrName =
855  spirv::ExecutionModeOp::getValuesAttrName(result.name);
856  result.addAttribute(valuesAttrName,
857  parser.getBuilder().getI32ArrayAttr(values));
858  return success();
859 }
860 
// Prints ` @fn "<ExecutionMode>"` followed by `, v0, v1, ...` when there are
// literal values — the form read by ExecutionModeOp::parse.
// NOTE(review): the signature line (original line 861, which declares
// `printer`) is missing from this extract.
862  printer << " ";
863  printer.printSymbolName(getFn());
864  printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\"";
865  auto values = this->getValues();
866  if (values.empty())
867  return;
868  printer << ", ";
869  llvm::interleaveComma(values, printer, [&](Attribute a) {
870  printer << llvm::cast<IntegerAttr>(a).getInt();
871  });
872 }
873 
874 //===----------------------------------------------------------------------===//
875 // spirv.func
876 //===----------------------------------------------------------------------===//
877 
// Parses a spirv.func:
//   @name(args) -> results "<FunctionControl>" (attributes {...})? ({ body })?
// The function type is built from the parsed argument and result types, and
// the body region is optional (an absent body yields an external function).
// NOTE(review): original lines 879 (presumably the `entryArgs` declaration),
// 892 (the signature-parsing call that ends on line 894) and 915 (the call
// that attaches argument/result attributes, continued on lines 916-917) are
// missing from this extract.
878 ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
880  SmallVector<DictionaryAttr> resultAttrs;
881  SmallVector<Type> resultTypes;
882  auto &builder = parser.getBuilder();
883 
884  // Parse the name as a symbol.
885  StringAttr nameAttr;
886  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
887  result.attributes))
888  return failure();
889 
890  // Parse the function signature.
// Variadic signatures are not allowed for spirv.func.
891  bool isVariadic = false;
893  parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
894  resultAttrs))
895  return failure();
896 
897  SmallVector<Type> argTypes;
898  for (auto &arg : entryArgs)
899  argTypes.push_back(arg.type);
900  auto fnType = builder.getFunctionType(argTypes, resultTypes);
901  result.addAttribute(getFunctionTypeAttrName(result.name),
902  TypeAttr::get(fnType));
903 
904  // Parse the optional function control keyword.
// NOTE(review): any failure of parseEnumStrAttr aborts the parse here; confirm
// whether it treats a missing keyword as success (the printer always emits
// the control string).
905  spirv::FunctionControl fnControl;
906  if (parseEnumStrAttr<spirv::FunctionControlAttr>(fnControl, parser, result))
907  return failure();
908 
909  // If additional attributes are present, parse them.
910  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
911  return failure();
912 
913  // Attach the parsed argument and result attributes to the function.
914  assert(resultAttrs.size() == resultTypes.size());
916  builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
917  getResAttrsAttrName(result.name));
918 
919  // Parse the optional function body.
// Absence of a region is not an error; only a region that fails to parse is.
920  auto *body = result.addRegion();
921  OptionalParseResult parseResult =
922  parser.parseOptionalRegion(*body, entryArgs);
923  return failure(parseResult.has_value() && failed(*parseResult));
924 }
925 
926 void spirv::FuncOp::print(OpAsmPrinter &printer) {
927  // Print function name, signature, and control.
928  printer << " ";
929  printer.printSymbolName(getSymName());
930  auto fnType = getFunctionType();
932  printer, *this, fnType.getInputs(),
933  /*isVariadic=*/false, fnType.getResults());
934  printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl())
935  << "\"";
937  printer, *this,
938  {spirv::attributeName<spirv::FunctionControl>(),
939  getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
940  getFunctionControlAttrName()});
941 
942  // Print the body if this is not an external function.
943  Region &body = this->getBody();
944  if (!body.empty()) {
945  printer << ' ';
946  printer.printRegion(body, /*printEntryBlockArgs=*/false,
947  /*printBlockTerminators=*/true);
948  }
949 }
950 
951 LogicalResult spirv::FuncOp::verifyType() {
952  FunctionType fnType = getFunctionType();
953  if (fnType.getNumResults() > 1)
954  return emitOpError("cannot have more than one result");
955 
956  auto hasDecorationAttr = [&](spirv::Decoration decoration,
957  unsigned argIndex) {
958  auto func = llvm::cast<FunctionOpInterface>(getOperation());
959  for (auto argAttr : cast<FunctionOpInterface>(func).getArgAttrs(argIndex)) {
960  if (argAttr.getName() != spirv::DecorationAttr::name)
961  continue;
962  if (auto decAttr = dyn_cast<spirv::DecorationAttr>(argAttr.getValue()))
963  return decAttr.getValue() == decoration;
964  }
965  return false;
966  };
967 
968  for (unsigned i = 0, e = this->getNumArguments(); i != e; ++i) {
969  Type param = fnType.getInputs()[i];
970  auto inputPtrType = dyn_cast<spirv::PointerType>(param);
971  if (!inputPtrType)
972  continue;
973 
974  auto pointeePtrType =
975  dyn_cast<spirv::PointerType>(inputPtrType.getPointeeType());
976  if (pointeePtrType) {
977  // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
978  // > If an OpFunctionParameter is a pointer (or contains a pointer)
979  // > and the type it points to is a pointer in the PhysicalStorageBuffer
980  // > storage class, the function parameter must be decorated with exactly
981  // > one of AliasedPointer or RestrictPointer.
982  if (pointeePtrType.getStorageClass() !=
983  spirv::StorageClass::PhysicalStorageBuffer)
984  continue;
985 
986  bool hasAliasedPtr =
987  hasDecorationAttr(spirv::Decoration::AliasedPointer, i);
988  bool hasRestrictPtr =
989  hasDecorationAttr(spirv::Decoration::RestrictPointer, i);
990  if (!hasAliasedPtr && !hasRestrictPtr)
991  return emitOpError()
992  << "with a pointer points to a physical buffer pointer must "
993  "be decorated either 'AliasedPointer' or 'RestrictPointer'";
994  continue;
995  }
996  // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
997  // > If an OpFunctionParameter is a pointer (or contains a pointer) in
998  // > the PhysicalStorageBuffer storage class, the function parameter must
999  // > be decorated with exactly one of Aliased or Restrict.
1000  if (auto pointeeArrayType =
1001  dyn_cast<spirv::ArrayType>(inputPtrType.getPointeeType())) {
1002  pointeePtrType =
1003  dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
1004  } else {
1005  pointeePtrType = inputPtrType;
1006  }
1007 
1008  if (!pointeePtrType || pointeePtrType.getStorageClass() !=
1009  spirv::StorageClass::PhysicalStorageBuffer)
1010  continue;
1011 
1012  bool hasAliased = hasDecorationAttr(spirv::Decoration::Aliased, i);
1013  bool hasRestrict = hasDecorationAttr(spirv::Decoration::Restrict, i);
1014  if (!hasAliased && !hasRestrict)
1015  return emitOpError() << "with physical buffer pointer must be decorated "
1016  "either 'Aliased' or 'Restrict'";
1017  }
1018 
1019  return success();
1020 }
1021 
1022 LogicalResult spirv::FuncOp::verifyBody() {
1023  FunctionType fnType = getFunctionType();
1024  if (!isExternal()) {
1025  Block &entryBlock = front();
1026 
1027  unsigned numArguments = this->getNumArguments();
1028  if (entryBlock.getNumArguments() != numArguments)
1029  return emitOpError("entry block must have ")
1030  << numArguments << " arguments to match function signature";
1031 
1032  for (auto [index, fnArgType, blockArgType] :
1033  llvm::enumerate(getArgumentTypes(), entryBlock.getArgumentTypes())) {
1034  if (blockArgType != fnArgType) {
1035  return emitOpError("type of entry block argument #")
1036  << index << '(' << blockArgType
1037  << ") must match the type of the corresponding argument in "
1038  << "function signature(" << fnArgType << ')';
1039  }
1040  }
1041  }
1042 
1043  auto walkResult = walk([fnType](Operation *op) -> WalkResult {
1044  if (auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
1045  if (fnType.getNumResults() != 0)
1046  return retOp.emitOpError("cannot be used in functions returning value");
1047  } else if (auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
1048  if (fnType.getNumResults() != 1)
1049  return retOp.emitOpError(
1050  "returns 1 value but enclosing function requires ")
1051  << fnType.getNumResults() << " results";
1052 
1053  auto retOperandType = retOp.getValue().getType();
1054  auto fnResultType = fnType.getResult(0);
1055  if (retOperandType != fnResultType)
1056  return retOp.emitOpError(" return value's type (")
1057  << retOperandType << ") mismatch with function's result type ("
1058  << fnResultType << ")";
1059  }
1060  return WalkResult::advance();
1061  });
1062 
1063  // TODO: verify other bits like linkage type.
1064 
1065  return failure(walkResult.wasInterrupted());
1066 }
1067 
1068 void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
1069  StringRef name, FunctionType type,
1070  spirv::FunctionControl control,
1071  ArrayRef<NamedAttribute> attrs) {
1072  state.addAttribute(SymbolTable::getSymbolAttrName(),
1073  builder.getStringAttr(name));
1074  state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
1075  state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
1076  builder.getAttr<spirv::FunctionControlAttr>(control));
1077  state.attributes.append(attrs.begin(), attrs.end());
1078  state.addRegion();
1079 }
1080 
1081 //===----------------------------------------------------------------------===//
1082 // spirv.GLFClampOp
1083 //===----------------------------------------------------------------------===//
1084 
1085 ParseResult spirv::GLFClampOp::parse(OpAsmParser &parser,
1086  OperationState &result) {
1087  return parseOneResultSameOperandTypeOp(parser, result);
1088 }
1090 
1091 //===----------------------------------------------------------------------===//
1092 // spirv.GLUClampOp
1093 //===----------------------------------------------------------------------===//
1094 
1095 ParseResult spirv::GLUClampOp::parse(OpAsmParser &parser,
1096  OperationState &result) {
1097  return parseOneResultSameOperandTypeOp(parser, result);
1098 }
1100 
1101 //===----------------------------------------------------------------------===//
1102 // spirv.GLSClampOp
1103 //===----------------------------------------------------------------------===//
1104 
1105 ParseResult spirv::GLSClampOp::parse(OpAsmParser &parser,
1106  OperationState &result) {
1107  return parseOneResultSameOperandTypeOp(parser, result);
1108 }
1110 
1111 //===----------------------------------------------------------------------===//
1112 // spirv.GLFmaOp
1113 //===----------------------------------------------------------------------===//
1114 
1115 ParseResult spirv::GLFmaOp::parse(OpAsmParser &parser, OperationState &result) {
1116  return parseOneResultSameOperandTypeOp(parser, result);
1117 }
1118 void spirv::GLFmaOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1119 
1120 //===----------------------------------------------------------------------===//
1121 // spirv.GlobalVariable
1122 //===----------------------------------------------------------------------===//
1123 
1124 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
1125  Type type, StringRef name,
1126  unsigned descriptorSet, unsigned binding) {
1127  build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
1128  state.addAttribute(
1129  spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
1130  builder.getI32IntegerAttr(descriptorSet));
1131  state.addAttribute(
1132  spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
1133  builder.getI32IntegerAttr(binding));
1134 }
1135 
1136 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
1137  Type type, StringRef name,
1138  spirv::BuiltIn builtin) {
1139  build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
1140  state.addAttribute(
1141  spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
1142  builder.getStringAttr(spirv::stringifyBuiltIn(builtin)));
1143 }
1144 
1145 ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
1146  OperationState &result) {
1147  // Parse variable name.
1148  StringAttr nameAttr;
1149  StringRef initializerAttrName =
1150  spirv::GlobalVariableOp::getInitializerAttrName(result.name);
1151  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1152  result.attributes)) {
1153  return failure();
1154  }
1155 
1156  // Parse optional initializer
1157  if (succeeded(parser.parseOptionalKeyword(initializerAttrName))) {
1158  FlatSymbolRefAttr initSymbol;
1159  if (parser.parseLParen() ||
1160  parser.parseAttribute(initSymbol, Type(), initializerAttrName,
1161  result.attributes) ||
1162  parser.parseRParen())
1163  return failure();
1164  }
1165 
1166  if (parseVariableDecorations(parser, result)) {
1167  return failure();
1168  }
1169 
1170  Type type;
1171  StringRef typeAttrName =
1172  spirv::GlobalVariableOp::getTypeAttrName(result.name);
1173  auto loc = parser.getCurrentLocation();
1174  if (parser.parseColonType(type)) {
1175  return failure();
1176  }
1177  if (!llvm::isa<spirv::PointerType>(type)) {
1178  return parser.emitError(loc, "expected spirv.ptr type");
1179  }
1180  result.addAttribute(typeAttrName, TypeAttr::get(type));
1181 
1182  return success();
1183 }
1184 
1186  SmallVector<StringRef, 4> elidedAttrs{
1187  spirv::attributeName<spirv::StorageClass>()};
1188 
1189  // Print variable name.
1190  printer << ' ';
1191  printer.printSymbolName(getSymName());
1192  elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
1193 
1194  StringRef initializerAttrName = this->getInitializerAttrName();
1195  // Print optional initializer
1196  if (auto initializer = this->getInitializer()) {
1197  printer << " " << initializerAttrName << '(';
1198  printer.printSymbolName(*initializer);
1199  printer << ')';
1200  elidedAttrs.push_back(initializerAttrName);
1201  }
1202 
1203  StringRef typeAttrName = this->getTypeAttrName();
1204  elidedAttrs.push_back(typeAttrName);
1205  spirv::printVariableDecorations(*this, printer, elidedAttrs);
1206  printer << " : " << getType();
1207 }
1208 
1209 LogicalResult spirv::GlobalVariableOp::verify() {
1210  if (!llvm::isa<spirv::PointerType>(getType()))
1211  return emitOpError("result must be of a !spv.ptr type");
1212 
1213  // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
1214  // object. It cannot be Generic. It must be the same as the Storage Class
1215  // operand of the Result Type."
1216  // Also, Function storage class is reserved by spirv.Variable.
1217  auto storageClass = this->storageClass();
1218  if (storageClass == spirv::StorageClass::Generic ||
1219  storageClass == spirv::StorageClass::Function) {
1220  return emitOpError("storage class cannot be '")
1221  << stringifyStorageClass(storageClass) << "'";
1222  }
1223 
1224  if (auto init = (*this)->getAttrOfType<FlatSymbolRefAttr>(
1225  this->getInitializerAttrName())) {
1227  (*this)->getParentOp(), init.getAttr());
1228  // TODO: Currently only variable initialization with specialization
1229  // constants and other variables is supported. They could be normal
1230  // constants in the module scope as well.
1231  if (!initOp || !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp,
1232  spirv::SpecConstantCompositeOp>(initOp)) {
1233  return emitOpError("initializer must be result of a "
1234  "spirv.SpecConstant or spirv.GlobalVariable or "
1235  "spirv.SpecConstantCompositeOp op");
1236  }
1237  }
1238 
1239  return success();
1240 }
1241 
1242 //===----------------------------------------------------------------------===//
1243 // spirv.INTEL.SubgroupBlockRead
1244 //===----------------------------------------------------------------------===//
1245 
1247  if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
1248  return failure();
1249 
1250  return success();
1251 }
1252 
1253 //===----------------------------------------------------------------------===//
1254 // spirv.INTEL.SubgroupBlockWrite
1255 //===----------------------------------------------------------------------===//
1256 
1258  OperationState &result) {
1259  // Parse the storage class specification
1260  spirv::StorageClass storageClass;
1262  auto loc = parser.getCurrentLocation();
1263  Type elementType;
1264  if (parseEnumStrAttr(storageClass, parser) ||
1265  parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
1266  parser.parseType(elementType)) {
1267  return failure();
1268  }
1269 
1270  auto ptrType = spirv::PointerType::get(elementType, storageClass);
1271  if (auto valVecTy = llvm::dyn_cast<VectorType>(elementType))
1272  ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
1273 
1274  if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
1275  result.operands)) {
1276  return failure();
1277  }
1278  return success();
1279 }
1280 
1282  printer << " " << getPtr() << ", " << getValue() << " : "
1283  << getValue().getType();
1284 }
1285 
1287  if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
1288  return failure();
1289 
1290  return success();
1291 }
1292 
1293 //===----------------------------------------------------------------------===//
1294 // spirv.IAddCarryOp
1295 //===----------------------------------------------------------------------===//
1296 
1297 LogicalResult spirv::IAddCarryOp::verify() {
1299 }
1300 
1301 ParseResult spirv::IAddCarryOp::parse(OpAsmParser &parser,
1302  OperationState &result) {
1304 }
1305 
1306 void spirv::IAddCarryOp::print(OpAsmPrinter &printer) {
1307  ::printArithmeticExtendedBinaryOp(*this, printer);
1308 }
1309 
1310 //===----------------------------------------------------------------------===//
1311 // spirv.ISubBorrowOp
1312 //===----------------------------------------------------------------------===//
1313 
1314 LogicalResult spirv::ISubBorrowOp::verify() {
1316 }
1317 
1318 ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser,
1319  OperationState &result) {
1321 }
1322 
1324  ::printArithmeticExtendedBinaryOp(*this, printer);
1325 }
1326 
1327 //===----------------------------------------------------------------------===//
1328 // spirv.SMulExtended
1329 //===----------------------------------------------------------------------===//
1330 
1331 LogicalResult spirv::SMulExtendedOp::verify() {
1333 }
1334 
1335 ParseResult spirv::SMulExtendedOp::parse(OpAsmParser &parser,
1336  OperationState &result) {
1338 }
1339 
1341  ::printArithmeticExtendedBinaryOp(*this, printer);
1342 }
1343 
1344 //===----------------------------------------------------------------------===//
1345 // spirv.UMulExtended
1346 //===----------------------------------------------------------------------===//
1347 
1348 LogicalResult spirv::UMulExtendedOp::verify() {
1350 }
1351 
1352 ParseResult spirv::UMulExtendedOp::parse(OpAsmParser &parser,
1353  OperationState &result) {
1355 }
1356 
1358  ::printArithmeticExtendedBinaryOp(*this, printer);
1359 }
1360 
1361 //===----------------------------------------------------------------------===//
1362 // spirv.MemoryBarrierOp
1363 //===----------------------------------------------------------------------===//
1364 
1365 LogicalResult spirv::MemoryBarrierOp::verify() {
1366  return verifyMemorySemantics(getOperation(), getMemorySemantics());
1367 }
1368 
1369 //===----------------------------------------------------------------------===//
1370 // spirv.module
1371 //===----------------------------------------------------------------------===//
1372 
1373 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
1374  std::optional<StringRef> name) {
1375  OpBuilder::InsertionGuard guard(builder);
1376  builder.createBlock(state.addRegion());
1377  if (name) {
1378  state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
1379  builder.getStringAttr(*name));
1380  }
1381 }
1382 
1383 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
1384  spirv::AddressingModel addressingModel,
1385  spirv::MemoryModel memoryModel,
1386  std::optional<VerCapExtAttr> vceTriple,
1387  std::optional<StringRef> name) {
1388  state.addAttribute(
1389  "addressing_model",
1390  builder.getAttr<spirv::AddressingModelAttr>(addressingModel));
1391  state.addAttribute("memory_model",
1392  builder.getAttr<spirv::MemoryModelAttr>(memoryModel));
1393  OpBuilder::InsertionGuard guard(builder);
1394  builder.createBlock(state.addRegion());
1395  if (vceTriple)
1396  state.addAttribute(getVCETripleAttrName(), *vceTriple);
1397  if (name)
1398  state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
1399  builder.getStringAttr(*name));
1400 }
1401 
1402 ParseResult spirv::ModuleOp::parse(OpAsmParser &parser,
1403  OperationState &result) {
1404  Region *body = result.addRegion();
1405 
1406  // If the name is present, parse it.
1407  StringAttr nameAttr;
1408  (void)parser.parseOptionalSymbolName(
1409  nameAttr, mlir::SymbolTable::getSymbolAttrName(), result.attributes);
1410 
1411  // Parse attributes
1412  spirv::AddressingModel addrModel;
1413  spirv::MemoryModel memoryModel;
1414  if (spirv::parseEnumKeywordAttr<spirv::AddressingModelAttr>(addrModel, parser,
1415  result) ||
1416  spirv::parseEnumKeywordAttr<spirv::MemoryModelAttr>(memoryModel, parser,
1417  result))
1418  return failure();
1419 
1420  if (succeeded(parser.parseOptionalKeyword("requires"))) {
1421  spirv::VerCapExtAttr vceTriple;
1422  if (parser.parseAttribute(vceTriple,
1423  spirv::ModuleOp::getVCETripleAttrName(),
1424  result.attributes))
1425  return failure();
1426  }
1427 
1428  if (parser.parseOptionalAttrDictWithKeyword(result.attributes) ||
1429  parser.parseRegion(*body, /*arguments=*/{}))
1430  return failure();
1431 
1432  // Make sure we have at least one block.
1433  if (body->empty())
1434  body->push_back(new Block());
1435 
1436  return success();
1437 }
1438 
1439 void spirv::ModuleOp::print(OpAsmPrinter &printer) {
1440  if (std::optional<StringRef> name = getName()) {
1441  printer << ' ';
1442  printer.printSymbolName(*name);
1443  }
1444 
1445  SmallVector<StringRef, 2> elidedAttrs;
1446 
1447  printer << " " << spirv::stringifyAddressingModel(getAddressingModel()) << " "
1448  << spirv::stringifyMemoryModel(getMemoryModel());
1449  auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
1450  auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
1451  elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
1453 
1454  if (std::optional<spirv::VerCapExtAttr> triple = getVceTriple()) {
1455  printer << " requires " << *triple;
1456  elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
1457  }
1458 
1459  printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs);
1460  printer << ' ';
1461  printer.printRegion(getRegion());
1462 }
1463 
1464 LogicalResult spirv::ModuleOp::verifyRegions() {
1465  Dialect *dialect = (*this)->getDialect();
1467  entryPoints;
1468  mlir::SymbolTable table(*this);
1469 
1470  for (auto &op : *getBody()) {
1471  if (op.getDialect() != dialect)
1472  return op.emitError("'spirv.module' can only contain spirv.* ops");
1473 
1474  // For EntryPoint op, check that the function and execution model is not
1475  // duplicated in EntryPointOps. Also verify that the interface specified
1476  // comes from globalVariables here to make this check cheaper.
1477  if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
1478  auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.getFn());
1479  if (!funcOp) {
1480  return entryPointOp.emitError("function '")
1481  << entryPointOp.getFn() << "' not found in 'spirv.module'";
1482  }
1483  if (auto interface = entryPointOp.getInterface()) {
1484  for (Attribute varRef : interface) {
1485  auto varSymRef = llvm::dyn_cast<FlatSymbolRefAttr>(varRef);
1486  if (!varSymRef) {
1487  return entryPointOp.emitError(
1488  "expected symbol reference for interface "
1489  "specification instead of '")
1490  << varRef;
1491  }
1492  auto variableOp =
1493  table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
1494  if (!variableOp) {
1495  return entryPointOp.emitError("expected spirv.GlobalVariable "
1496  "symbol reference instead of'")
1497  << varSymRef << "'";
1498  }
1499  }
1500  }
1501 
1502  auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
1503  funcOp, entryPointOp.getExecutionModel());
1504  if (!entryPoints.try_emplace(key, entryPointOp).second)
1505  return entryPointOp.emitError("duplicate of a previous EntryPointOp");
1506  } else if (auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
1507  // If the function is external and does not have 'Import'
1508  // linkage_attributes(LinkageAttributes), throw an error. 'Import'
1509  // LinkageAttributes is used to import external functions.
1510  auto linkageAttr = funcOp.getLinkageAttributes();
1511  auto hasImportLinkage =
1512  linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
1513  spirv::LinkageType::Import);
1514  if (funcOp.isExternal() && !hasImportLinkage)
1515  return op.emitError(
1516  "'spirv.module' cannot contain external functions "
1517  "without 'Import' linkage_attributes (LinkageAttributes)");
1518 
1519  // TODO: move this check to spirv.func.
1520  for (auto &block : funcOp)
1521  for (auto &op : block) {
1522  if (op.getDialect() != dialect)
1523  return op.emitError(
1524  "functions in 'spirv.module' can only contain spirv.* ops");
1525  }
1526  }
1527  }
1528 
1529  return success();
1530 }
1531 
1532 //===----------------------------------------------------------------------===//
1533 // spirv.mlir.referenceof
1534 //===----------------------------------------------------------------------===//
1535 
1536 LogicalResult spirv::ReferenceOfOp::verify() {
1537  auto *specConstSym = SymbolTable::lookupNearestSymbolFrom(
1538  (*this)->getParentOp(), getSpecConstAttr());
1539  Type constType;
1540 
1541  auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
1542  if (specConstOp)
1543  constType = specConstOp.getDefaultValue().getType();
1544 
1545  auto specConstCompositeOp =
1546  dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
1547  if (specConstCompositeOp)
1548  constType = specConstCompositeOp.getType();
1549 
1550  if (!specConstOp && !specConstCompositeOp)
1551  return emitOpError(
1552  "expected spirv.SpecConstant or spirv.SpecConstantComposite symbol");
1553 
1554  if (getReference().getType() != constType)
1555  return emitOpError("result type mismatch with the referenced "
1556  "specialization constant's type");
1557 
1558  return success();
1559 }
1560 
1561 //===----------------------------------------------------------------------===//
1562 // spirv.SpecConstant
1563 //===----------------------------------------------------------------------===//
1564 
1565 ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
1566  OperationState &result) {
1567  StringAttr nameAttr;
1568  Attribute valueAttr;
1569  StringRef defaultValueAttrName =
1570  spirv::SpecConstantOp::getDefaultValueAttrName(result.name);
1571 
1572  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1573  result.attributes))
1574  return failure();
1575 
1576  // Parse optional spec_id.
1577  if (succeeded(parser.parseOptionalKeyword(kSpecIdAttrName))) {
1578  IntegerAttr specIdAttr;
1579  if (parser.parseLParen() ||
1580  parser.parseAttribute(specIdAttr, kSpecIdAttrName, result.attributes) ||
1581  parser.parseRParen())
1582  return failure();
1583  }
1584 
1585  if (parser.parseEqual() ||
1586  parser.parseAttribute(valueAttr, defaultValueAttrName, result.attributes))
1587  return failure();
1588 
1589  return success();
1590 }
1591 
1593  printer << ' ';
1594  printer.printSymbolName(getSymName());
1595  if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
1596  printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')';
1597  printer << " = " << getDefaultValue();
1598 }
1599 
1600 LogicalResult spirv::SpecConstantOp::verify() {
1601  if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
1602  if (specID.getValue().isNegative())
1603  return emitOpError("SpecId cannot be negative");
1604 
1605  auto value = getDefaultValue();
1606  if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
1607  // Make sure bitwidth is allowed.
1608  if (!llvm::isa<spirv::SPIRVType>(value.getType()))
1609  return emitOpError("default value bitwidth disallowed");
1610  return success();
1611  }
1612  return emitOpError(
1613  "default value can only be a bool, integer, or float scalar");
1614 }
1615 
1616 //===----------------------------------------------------------------------===//
1617 // spirv.VectorShuffle
1618 //===----------------------------------------------------------------------===//
1619 
1620 LogicalResult spirv::VectorShuffleOp::verify() {
1621  VectorType resultType = llvm::cast<VectorType>(getType());
1622 
1623  size_t numResultElements = resultType.getNumElements();
1624  if (numResultElements != getComponents().size())
1625  return emitOpError("result type element count (")
1626  << numResultElements
1627  << ") mismatch with the number of component selectors ("
1628  << getComponents().size() << ")";
1629 
1630  size_t totalSrcElements =
1631  llvm::cast<VectorType>(getVector1().getType()).getNumElements() +
1632  llvm::cast<VectorType>(getVector2().getType()).getNumElements();
1633 
1634  for (const auto &selector : getComponents().getAsValueRange<IntegerAttr>()) {
1635  uint32_t index = selector.getZExtValue();
1636  if (index >= totalSrcElements &&
1637  index != std::numeric_limits<uint32_t>().max())
1638  return emitOpError("component selector ")
1639  << index << " out of range: expected to be in [0, "
1640  << totalSrcElements << ") or 0xffffffff";
1641  }
1642  return success();
1643 }
1644 
1645 //===----------------------------------------------------------------------===//
1646 // spirv.MatrixTimesScalar
1647 //===----------------------------------------------------------------------===//
1648 
1649 LogicalResult spirv::MatrixTimesScalarOp::verify() {
1650  Type elementType =
1651  llvm::TypeSwitch<Type, Type>(getMatrix().getType())
1653  [](auto matrixType) { return matrixType.getElementType(); })
1654  .Default([](Type) { return nullptr; });
1655 
1656  assert(elementType && "Unhandled type");
1657 
1658  // Check that the scalar type is the same as the matrix element type.
1659  if (getScalar().getType() != elementType)
1660  return emitOpError("input matrix components' type and scaling value must "
1661  "have the same type");
1662 
1663  return success();
1664 }
1665 
1666 //===----------------------------------------------------------------------===//
1667 // spirv.Transpose
1668 //===----------------------------------------------------------------------===//
1669 
1670 LogicalResult spirv::TransposeOp::verify() {
1671  auto inputMatrix = llvm::cast<spirv::MatrixType>(getMatrix().getType());
1672  auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().getType());
1673 
1674  // Verify that the input and output matrices have correct shapes.
1675  if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
1676  return emitError("input matrix rows count must be equal to "
1677  "output matrix columns count");
1678 
1679  if (inputMatrix.getNumColumns() != resultMatrix.getNumRows())
1680  return emitError("input matrix columns count must be equal to "
1681  "output matrix rows count");
1682 
1683  // Verify that the input and output matrices have the same component type
1684  if (inputMatrix.getElementType() != resultMatrix.getElementType())
1685  return emitError("input and output matrices must have the same "
1686  "component type");
1687 
1688  return success();
1689 }
1690 
1691 //===----------------------------------------------------------------------===//
1692 // spirv.MatrixTimesVector
1693 //===----------------------------------------------------------------------===//
1694 
1695 LogicalResult spirv::MatrixTimesVectorOp::verify() {
1696  auto matrixType = llvm::cast<spirv::MatrixType>(getMatrix().getType());
1697  auto vectorType = llvm::cast<VectorType>(getVector().getType());
1698  auto resultType = llvm::cast<VectorType>(getType());
1699 
1700  if (matrixType.getNumColumns() != vectorType.getNumElements())
1701  return emitOpError("matrix columns (")
1702  << matrixType.getNumColumns() << ") must match vector operand size ("
1703  << vectorType.getNumElements() << ")";
1704 
1705  if (resultType.getNumElements() != matrixType.getNumRows())
1706  return emitOpError("result size (")
1707  << resultType.getNumElements() << ") must match the matrix rows ("
1708  << matrixType.getNumRows() << ")";
1709 
1710  if (matrixType.getElementType() != resultType.getElementType())
1711  return emitOpError("matrix and result element types must match");
1712 
1713  return success();
1714 }
1715 
1716 //===----------------------------------------------------------------------===//
1717 // spirv.VectorTimesMatrix
1718 //===----------------------------------------------------------------------===//
1719 
1720 LogicalResult spirv::VectorTimesMatrixOp::verify() {
1721  auto vectorType = llvm::cast<VectorType>(getVector().getType());
1722  auto matrixType = llvm::cast<spirv::MatrixType>(getMatrix().getType());
1723  auto resultType = llvm::cast<VectorType>(getType());
1724 
1725  if (matrixType.getNumRows() != vectorType.getNumElements())
1726  return emitOpError("number of components in vector must equal the number "
1727  "of components in each column in matrix");
1728 
1729  if (resultType.getNumElements() != matrixType.getNumColumns())
1730  return emitOpError("number of columns in matrix must equal the number of "
1731  "components in result");
1732 
1733  if (matrixType.getElementType() != resultType.getElementType())
1734  return emitOpError("matrix must be a matrix with the same component type "
1735  "as the component type in result");
1736 
1737  return success();
1738 }
1739 
1740 //===----------------------------------------------------------------------===//
1741 // spirv.MatrixTimesMatrix
1742 //===----------------------------------------------------------------------===//
1743 
1744 LogicalResult spirv::MatrixTimesMatrixOp::verify() {
1745  auto leftMatrix = llvm::cast<spirv::MatrixType>(getLeftmatrix().getType());
1746  auto rightMatrix = llvm::cast<spirv::MatrixType>(getRightmatrix().getType());
1747  auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().getType());
1748 
1749  // left matrix columns' count and right matrix rows' count must be equal
1750  if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
1751  return emitError("left matrix columns' count must be equal to "
1752  "the right matrix rows' count");
1753 
1754  // right and result matrices columns' count must be the same
1755  if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns())
1756  return emitError(
1757  "right and result matrices must have equal columns' count");
1758 
1759  // right and result matrices component type must be the same
1760  if (rightMatrix.getElementType() != resultMatrix.getElementType())
1761  return emitError("right and result matrices' component type must"
1762  " be the same");
1763 
1764  // left and result matrices component type must be the same
1765  if (leftMatrix.getElementType() != resultMatrix.getElementType())
1766  return emitError("left and result matrices' component type"
1767  " must be the same");
1768 
1769  // left and result matrices rows count must be the same
1770  if (leftMatrix.getNumRows() != resultMatrix.getNumRows())
1771  return emitError("left and result matrices must have equal rows' count");
1772 
1773  return success();
1774 }
1775 
1776 //===----------------------------------------------------------------------===//
1777 // spirv.SpecConstantComposite
1778 //===----------------------------------------------------------------------===//
1779 
1781  OperationState &result) {
1782 
1783  StringAttr compositeName;
1784  if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
1785  result.attributes))
1786  return failure();
1787 
1788  if (parser.parseLParen())
1789  return failure();
1790 
1791  SmallVector<Attribute, 4> constituents;
1792 
1793  do {
1794  // The name of the constituent attribute isn't important
1795  const char *attrName = "spec_const";
1796  FlatSymbolRefAttr specConstRef;
1797  NamedAttrList attrs;
1798 
1799  if (parser.parseAttribute(specConstRef, Type(), attrName, attrs))
1800  return failure();
1801 
1802  constituents.push_back(specConstRef);
1803  } while (!parser.parseOptionalComma());
1804 
1805  if (parser.parseRParen())
1806  return failure();
1807 
1808  StringAttr compositeSpecConstituentsName =
1809  spirv::SpecConstantCompositeOp::getConstituentsAttrName(result.name);
1810  result.addAttribute(compositeSpecConstituentsName,
1811  parser.getBuilder().getArrayAttr(constituents));
1812 
1813  Type type;
1814  if (parser.parseColonType(type))
1815  return failure();
1816 
1817  StringAttr typeAttrName =
1818  spirv::SpecConstantCompositeOp::getTypeAttrName(result.name);
1819  result.addAttribute(typeAttrName, TypeAttr::get(type));
1820 
1821  return success();
1822 }
1823 
1825  printer << " ";
1826  printer.printSymbolName(getSymName());
1827  printer << " (";
1828  auto constituents = this->getConstituents().getValue();
1829 
1830  if (!constituents.empty())
1831  llvm::interleaveComma(constituents, printer);
1832 
1833  printer << ") : " << getType();
1834 }
1835 
1837  auto cType = llvm::dyn_cast<spirv::CompositeType>(getType());
1838  auto constituents = this->getConstituents().getValue();
1839 
1840  if (!cType)
1841  return emitError("result type must be a composite type, but provided ")
1842  << getType();
1843 
1844  if (llvm::isa<spirv::CooperativeMatrixType>(cType))
1845  return emitError("unsupported composite type ") << cType;
1846  if (constituents.size() != cType.getNumElements())
1847  return emitError("has incorrect number of operands: expected ")
1848  << cType.getNumElements() << ", but provided "
1849  << constituents.size();
1850 
1851  for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
1852  auto constituent = llvm::cast<FlatSymbolRefAttr>(constituents[index]);
1853 
1854  auto constituentSpecConstOp =
1855  dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom(
1856  (*this)->getParentOp(), constituent.getAttr()));
1857 
1858  if (constituentSpecConstOp.getDefaultValue().getType() !=
1859  cType.getElementType(index))
1860  return emitError("has incorrect types of operands: expected ")
1861  << cType.getElementType(index) << ", but provided "
1862  << constituentSpecConstOp.getDefaultValue().getType();
1863  }
1864 
1865  return success();
1866 }
1867 
1868 //===----------------------------------------------------------------------===//
1869 // spirv.SpecConstantOperation
1870 //===----------------------------------------------------------------------===//
1871 
1873  OperationState &result) {
1874  Region *body = result.addRegion();
1875 
1876  if (parser.parseKeyword("wraps"))
1877  return failure();
1878 
1879  body->push_back(new Block);
1880  Block &block = body->back();
1881  Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
1882 
1883  if (!wrappedOp)
1884  return failure();
1885 
1886  OpBuilder builder(parser.getContext());
1887  builder.setInsertionPointToEnd(&block);
1888  builder.create<spirv::YieldOp>(wrappedOp->getLoc(), wrappedOp->getResult(0));
1889  result.location = wrappedOp->getLoc();
1890 
1891  result.addTypes(wrappedOp->getResult(0).getType());
1892 
1893  if (parser.parseOptionalAttrDict(result.attributes))
1894  return failure();
1895 
1896  return success();
1897 }
1898 
1900  printer << " wraps ";
1901  printer.printGenericOp(&getBody().front().front());
1902 }
1903 
1904 LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
1905  Block &block = getRegion().getBlocks().front();
1906 
1907  if (block.getOperations().size() != 2)
1908  return emitOpError("expected exactly 2 nested ops");
1909 
1910  Operation &enclosedOp = block.getOperations().front();
1911 
1913  return emitOpError("invalid enclosed op");
1914 
1915  for (auto operand : enclosedOp.getOperands())
1916  if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
1917  spirv::SpecConstantOperationOp>(operand.getDefiningOp()))
1918  return emitOpError(
1919  "invalid operand, must be defined by a constant operation");
1920 
1921  return success();
1922 }
1923 
1924 //===----------------------------------------------------------------------===//
1925 // spirv.GL.FrexpStruct
1926 //===----------------------------------------------------------------------===//
1927 
1928 LogicalResult spirv::GLFrexpStructOp::verify() {
1929  spirv::StructType structTy =
1930  llvm::dyn_cast<spirv::StructType>(getResult().getType());
1931 
1932  if (structTy.getNumElements() != 2)
1933  return emitError("result type must be a struct type with two memebers");
1934 
1935  Type significandTy = structTy.getElementType(0);
1936  Type exponentTy = structTy.getElementType(1);
1937  VectorType exponentVecTy = llvm::dyn_cast<VectorType>(exponentTy);
1938  IntegerType exponentIntTy = llvm::dyn_cast<IntegerType>(exponentTy);
1939 
1940  Type operandTy = getOperand().getType();
1941  VectorType operandVecTy = llvm::dyn_cast<VectorType>(operandTy);
1942  FloatType operandFTy = llvm::dyn_cast<FloatType>(operandTy);
1943 
1944  if (significandTy != operandTy)
1945  return emitError("member zero of the resulting struct type must be the "
1946  "same type as the operand");
1947 
1948  if (exponentVecTy) {
1949  IntegerType componentIntTy =
1950  llvm::dyn_cast<IntegerType>(exponentVecTy.getElementType());
1951  if (!componentIntTy || componentIntTy.getWidth() != 32)
1952  return emitError("member one of the resulting struct type must"
1953  "be a scalar or vector of 32 bit integer type");
1954  } else if (!exponentIntTy || exponentIntTy.getWidth() != 32) {
1955  return emitError("member one of the resulting struct type "
1956  "must be a scalar or vector of 32 bit integer type");
1957  }
1958 
1959  // Check that the two member types have the same number of components
1960  if (operandVecTy && exponentVecTy &&
1961  (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
1962  return success();
1963 
1964  if (operandFTy && exponentIntTy)
1965  return success();
1966 
1967  return emitError("member one of the resulting struct type must have the same "
1968  "number of components as the operand type");
1969 }
1970 
1971 //===----------------------------------------------------------------------===//
1972 // spirv.GL.Ldexp
1973 //===----------------------------------------------------------------------===//
1974 
1975 LogicalResult spirv::GLLdexpOp::verify() {
1976  Type significandType = getX().getType();
1977  Type exponentType = getExp().getType();
1978 
1979  if (llvm::isa<FloatType>(significandType) !=
1980  llvm::isa<IntegerType>(exponentType))
1981  return emitOpError("operands must both be scalars or vectors");
1982 
1983  auto getNumElements = [](Type type) -> unsigned {
1984  if (auto vectorType = llvm::dyn_cast<VectorType>(type))
1985  return vectorType.getNumElements();
1986  return 1;
1987  };
1988 
1989  if (getNumElements(significandType) != getNumElements(exponentType))
1990  return emitOpError("operands must have the same number of elements");
1991 
1992  return success();
1993 }
1994 
1995 //===----------------------------------------------------------------------===//
1996 // spirv.ShiftLeftLogicalOp
1997 //===----------------------------------------------------------------------===//
1998 
1999 LogicalResult spirv::ShiftLeftLogicalOp::verify() {
2000  return verifyShiftOp(*this);
2001 }
2002 
2003 //===----------------------------------------------------------------------===//
2004 // spirv.ShiftRightArithmeticOp
2005 //===----------------------------------------------------------------------===//
2006 
2007 LogicalResult spirv::ShiftRightArithmeticOp::verify() {
2008  return verifyShiftOp(*this);
2009 }
2010 
2011 //===----------------------------------------------------------------------===//
2012 // spirv.ShiftRightLogicalOp
2013 //===----------------------------------------------------------------------===//
2014 
2015 LogicalResult spirv::ShiftRightLogicalOp::verify() {
2016  return verifyShiftOp(*this);
2017 }
2018 
2019 //===----------------------------------------------------------------------===//
2020 // spirv.VectorTimesScalarOp
2021 //===----------------------------------------------------------------------===//
2022 
2023 LogicalResult spirv::VectorTimesScalarOp::verify() {
2024  if (getVector().getType() != getType())
2025  return emitOpError("vector operand and result type mismatch");
2026  auto scalarType = llvm::cast<VectorType>(getType()).getElementType();
2027  if (getScalar().getType() != scalarType)
2028  return emitOpError("scalar operand and result element type match");
2029  return success();
2030 }
static std::string bindingName()
Returns the string name of the Binding decoration.
static std::string descriptorSetName()
Returns the string name of the DescriptorSet decoration.
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser, OperationState &result)
Definition: SPIRVOps.cpp:268
static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, Type opType)
Definition: SPIRVOps.cpp:558
static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, OperationState &result)
Definition: SPIRVOps.cpp:121
static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op)
Definition: SPIRVOps.cpp:254
static LogicalResult verifyShiftOp(Operation *op)
Definition: SPIRVOps.cpp:300
static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op, Value ptr, Value val)
Definition: SPIRVOps.cpp:170
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:187
static void printOneResultOp(Operation *op, OpAsmPrinter &p)
Definition: SPIRVOps.cpp:150
static void printArithmeticExtendedBinaryOp(Operation *op, OpAsmPrinter &printer)
Definition: SPIRVOps.cpp:292
const float * table
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseOptionalSymbolName(StringAttr &result)=0
Parse an optional @-identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
ParseResult addTypesToList(ArrayRef< Type > types, SmallVectorImpl< Type > &result)
Add the specified types to the end of the specified type list and return success.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:151
unsigned getNumArguments()
Definition: Block.h:128
OpListType & getOperations()
Definition: Block.h:137
Operation & front()
Definition: Block.h:153
iterator begin()
Definition: Block.h:143
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:196
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:224
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:272
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:250
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:76
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:67
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:96
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:258
MLIRContext * getContext() const
Definition: Builders.h:56
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:262
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition: Builders.h:96
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual Operation * parseGenericOperation(Block *insertBlock, Block::iterator insertPt)=0
Parse an operation in its generic form.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printGenericOp(Operation *op, bool printOpName=true)=0
Print the entire operation with the default generic assembly form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:434
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:426
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
A trait to mark ops that can be enclosed/wrapped in a SpecConstantOperation op.
Definition: SPIRVOpTraits.h:33
type_range getType() const
Definition: ValueRange.cpp:32
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:750
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:550
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
void push_back(Block *block)
Definition: Region.h:61
bool empty()
Definition: Region.h:60
Block & back()
Definition: Region.h:64
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:107
Type front()
Return first type in the range.
Definition: TypeRange.h:152
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:33
static WalkResult advance()
Definition: Visitors.h:51
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:406
SPIR-V struct type.
Definition: SPIRVTypes.h:293
unsigned getNumElements() const
Type getElementType(unsigned) const
An attribute that specifies the SPIR-V (version, capabilities, extensions) triple.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:136
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the argument at 'index'.
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
constexpr char kFnNameAttrName[]
constexpr char kSpecIdAttrName[]
LogicalResult verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics)
Definition: SPIRVOps.cpp:70
ParseResult parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next string attribute in parser as an enumerant of the given EnumClass.
void printVariableDecorations(Operation *op, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs)
Definition: SPIRVOps.cpp:94
AddressingModel getAddressingModel(TargetEnvAttr targetAttr, bool use64bitAddress)
Returns addressing model selected based on target environment.
FailureOr< ExecutionModel > getExecutionModel(TargetEnvAttr targetAttr)
Returns execution model selected based on target environment.
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
LogicalResult extractValueFromConstOp(Operation *op, int32_t &value)
Definition: SPIRVOps.cpp:50
ParseResult parseVariableDecorations(OpAsmParser &parser, OperationState &state)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:424
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.