MLIR  21.0.0git
SPIRVOps.cpp
Go to the documentation of this file.
1 //===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the operations in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
15 #include "SPIRVOpUtils.h"
16 #include "SPIRVParsingUtils.h"
17 
24 #include "mlir/IR/Builders.h"
25 #include "mlir/IR/BuiltinTypes.h"
26 #include "mlir/IR/Matchers.h"
27 #include "mlir/IR/OpDefinition.h"
29 #include "mlir/IR/Operation.h"
30 #include "mlir/IR/TypeUtilities.h"
32 #include "llvm/ADT/APFloat.h"
33 #include "llvm/ADT/APInt.h"
34 #include "llvm/ADT/ArrayRef.h"
35 #include "llvm/ADT/STLExtras.h"
36 #include "llvm/ADT/StringExtras.h"
37 #include "llvm/ADT/TypeSwitch.h"
38 #include "llvm/Support/InterleavedRange.h"
39 #include <cassert>
40 #include <numeric>
41 #include <optional>
42 #include <type_traits>
43 
44 using namespace mlir;
45 using namespace mlir::spirv::AttrNames;
46 
47 //===----------------------------------------------------------------------===//
48 // Common utility functions
49 //===----------------------------------------------------------------------===//
50 
51 LogicalResult spirv::extractValueFromConstOp(Operation *op, int32_t &value) {
52  auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
53  if (!constOp) {
54  return failure();
55  }
56  auto valueAttr = constOp.getValue();
57  auto integerValueAttr = llvm::dyn_cast<IntegerAttr>(valueAttr);
58  if (!integerValueAttr) {
59  return failure();
60  }
61 
62  if (integerValueAttr.getType().isSignlessInteger())
63  value = integerValueAttr.getInt();
64  else
65  value = integerValueAttr.getSInt();
66 
67  return success();
68 }
69 
// Verifies that at most one of the four memory-ordering bits (Acquire,
// Release, AcquireRelease, SequentiallyConsistent) is set in
// `memorySemantics`; emits an error on `op` and fails otherwise. Bits outside
// those four are not inspected here.
70 LogicalResult
72  spirv::MemorySemantics memorySemantics) {
73  // According to the SPIR-V specification:
74  // "Despite being a mask and allowing multiple bits to be combined, it is
75  // invalid for more than one of these four bits to be set: Acquire, Release,
76  // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and
77  // Release semantics is done by setting the AcquireRelease bit, not by setting
78  // two bits."
79  auto atMostOneInSet = spirv::MemorySemantics::Acquire |
80  spirv::MemorySemantics::Release |
81  spirv::MemorySemantics::AcquireRelease |
82  spirv::MemorySemantics::SequentiallyConsistent;
83 
// Mask off every other bit so that only the four ordering bits are counted.
84  auto bitCount =
85  llvm::popcount(static_cast<uint32_t>(memorySemantics & atMostOneInSet));
86  if (bitCount > 1) {
// NOTE(review): the adjacent string literals concatenate to
// "...`Release`,`AcquireRelease`..." with no space after the comma; fixing it
// changes the emitted diagnostic text (and any tests matching it).
87  return op->emitError(
88  "expected at most one of these four memory constraints "
89  "to be set: `Acquire`, `Release`,"
90  "`AcquireRelease` or `SequentiallyConsistent`");
91  }
92  return success();
93 }
94 
96  SmallVectorImpl<StringRef> &elidedAttrs) {
// Prints decoration attributes of `op` in a custom form:
//   * ` bind(<set>, <binding>)` when both the descriptor-set and binding
//     attributes are present as IntegerAttrs;
//   * ` <builtin-attr-name>("<value>")` when the BuiltIn attribute is a
//     StringAttr.
// It then prints the remaining attributes as an optional dictionary, skipping
// `elidedAttrs` plus everything printed above.
// NOTE(review): the names pushed into `elidedAttrs` are StringRefs into the
// local results of `convertToSnakeFromCamelCase` (presumably std::string).
// They dangle once this function returns; confirm that callers do not read
// `elidedAttrs` afterwards.
97  // Print optional descriptor binding
98  auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
99  stringifyDecoration(spirv::Decoration::DescriptorSet));
100  auto bindingName = llvm::convertToSnakeFromCamelCase(
101  stringifyDecoration(spirv::Decoration::Binding));
102  auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
103  auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
// The compact `bind(...)` form is used only when both halves are present;
// otherwise they fall through to the generic attribute dictionary.
104  if (descriptorSet && binding) {
105  elidedAttrs.push_back(descriptorSetName);
106  elidedAttrs.push_back(bindingName);
107  printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
108  << ")";
109  }
110 
111  // Print BuiltIn attribute if present
112  auto builtInName = llvm::convertToSnakeFromCamelCase(
113  stringifyDecoration(spirv::Decoration::BuiltIn));
114  if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
115  printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
116  elidedAttrs.push_back(builtInName);
117  }
118 
119  printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
120 }
121 
// Parses an op with a single result. Two forms are accepted:
//   * `(<operands>) {attrs} : (<input types>) -> <result types>`, the generic
//     form that `printOneResultOp` falls back to when types differ;
//   * `<operands> {attrs} : <type>`, where one type is used for every
//     operand and for the result.
122 static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
123  OperationState &result) {
125  Type type;
126  // If the operand list is in-between parentheses, then we have a generic form.
127  // (see the fallback in `printOneResultOp`).
128  SMLoc loc = parser.getCurrentLocation();
129  if (!parser.parseOptionalLParen()) {
130  if (parser.parseOperandList(ops) || parser.parseRParen() ||
131  parser.parseOptionalAttrDict(result.attributes) ||
132  parser.parseColon() || parser.parseType(type))
133  return failure();
134  auto fnType = llvm::dyn_cast<FunctionType>(type);
135  if (!fnType) {
136  parser.emitError(loc, "expected function type");
137  return failure();
138  }
139  if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands))
140  return failure();
141  result.addTypes(fnType.getResults());
142  return success();
143  }
// Short form: the single parsed type applies to every operand and the result.
144  return failure(parser.parseOperandList(ops) ||
145  parser.parseOptionalAttrDict(result.attributes) ||
146  parser.parseColonType(type) ||
147  parser.resolveOperands(ops, type, result.operands) ||
148  parser.addTypeToList(type, result.types));
149 }
150 
// Prints a one-result op in a short form ending in ` : <type>` when every
// operand type equals the result type. Otherwise it falls back to the generic
// form (without the op name) so that no type information is lost.
152  assert(op->getNumResults() == 1 && "op should have one result");
153 
154  // If not all the operand and result types are the same, just use the
155  // generic assembly form to avoid omitting information in printing.
156  auto resultType = op->getResult(0).getType();
157  if (llvm::any_of(op->getOperandTypes(),
158  [&](Type type) { return type != resultType; })) {
159  p.printGenericOp(op, /*printOpName=*/false);
160  return;
161  }
162 
163  p << ' ';
164  p.printOperands(op->getOperands());
166  // Now we can output only one type for all operands and the result.
167  p << " : " << resultType;
168 }
169 
170 template <typename BlockReadWriteOpTy>
171 static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
172  Value ptr, Value val) {
173  auto valType = val.getType();
174  if (auto valVecTy = llvm::dyn_cast<VectorType>(valType))
175  valType = valVecTy.getElementType();
176 
177  if (valType !=
178  llvm::cast<spirv::PointerType>(ptr.getType()).getPointeeType()) {
179  return op.emitOpError("mismatch in result type and pointer type");
180  }
181  return success();
182 }
183 
184 /// Walks the given type hierarchy with the given indices, potentially down
185 /// to component granularity, to select an element type. Returns null type and
186 /// emits errors through `emitErrorFn` on failure.
///
/// Each index must address a composite type. When the composite's element
/// count is known at compile time, the index must also lie in
/// [0, numElements).
/// NOTE(review): the diagnostics name spirv.CompositeExtract even when this
/// is reached from other ops (e.g. CompositeInsertOp::verify).
187 static Type
189  function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
190  if (indices.empty()) {
191  emitErrorFn("expected at least one index for spirv.CompositeExtract");
192  return nullptr;
193  }
194 
195  for (auto index : indices) {
196  if (auto cType = llvm::dyn_cast<spirv::CompositeType>(type)) {
// NOTE(review): bounds (including negativity) are only checked when the
// element count is known at compile time; otherwise `index` is passed to
// getElementType unchecked.
197  if (cType.hasCompileTimeKnownNumElements() &&
198  (index < 0 ||
199  static_cast<uint64_t>(index) >= cType.getNumElements())) {
200  emitErrorFn("index ") << index << " out of bounds for " << type;
201  return nullptr;
202  }
203  type = cType.getElementType(index);
204  } else {
205  emitErrorFn("cannot extract from non-composite type ")
206  << type << " with index " << index;
207  return nullptr;
208  }
209  }
210  return type;
211 }
212 
// Attribute-based front end for the ArrayRef<int32_t> overload above.
// `indices` must be a non-empty ArrayAttr of IntegerAttrs; their values are
// converted to int32_t and used to walk `type`. Errors are reported through
// `emitErrorFn`, and a null type is returned on failure.
213 static Type
215  function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
216  auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(indices);
217  if (!indicesArrayAttr) {
218  emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
219  return nullptr;
220  }
221  if (indicesArrayAttr.empty()) {
222  emitErrorFn("expected at least one index for spirv.CompositeExtract");
223  return nullptr;
224  }
225 
226  SmallVector<int32_t, 2> indexVals;
227  for (auto indexAttr : indicesArrayAttr) {
228  auto indexIntAttr = llvm::dyn_cast<IntegerAttr>(indexAttr);
229  if (!indexIntAttr) {
230  emitErrorFn("expected an 32-bit integer for index, but found '")
231  << indexAttr << "'";
232  return nullptr;
233  }
// NOTE(review): the attribute's bit width is not checked, so a wider integer
// value is silently narrowed to int32_t here.
234  indexVals.push_back(indexIntAttr.getInt());
235  }
236  return getElementType(type, indexVals, emitErrorFn);
237 }
238 
239 static Type getElementType(Type type, Attribute indices, Location loc) {
240  auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
241  return ::mlir::emitError(loc, err);
242  };
243  return getElementType(type, indices, errorFn);
244 }
245 
246 static Type getElementType(Type type, Attribute indices, OpAsmParser &parser,
247  SMLoc loc) {
248  auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
249  return parser.emitError(loc, err);
250  };
251  return getElementType(type, indices, errorFn);
252 }
253 
254 template <typename ExtendedBinaryOp>
255 static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op) {
256  auto resultType = llvm::cast<spirv::StructType>(op.getType());
257  if (resultType.getNumElements() != 2)
258  return op.emitOpError("expected result struct type containing two members");
259 
260  if (!llvm::all_equal({op.getOperand1().getType(), op.getOperand2().getType(),
261  resultType.getElementType(0),
262  resultType.getElementType(1)}))
263  return op.emitOpError(
264  "expected all operand types and struct member types are the same");
265 
266  return success();
267 }
268 
// Parses `{attrs} <operands> : <result-type>`. The result type must be a
// spirv.struct with exactly two members, and every operand is resolved with
// the struct's first member type.
269 static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser,
270  OperationState &result) {
272  if (parser.parseOptionalAttrDict(result.attributes) ||
273  parser.parseOperandList(operands) || parser.parseColon())
274  return failure();
275 
276  Type resultType;
277  SMLoc loc = parser.getCurrentLocation();
278  if (parser.parseType(resultType))
279  return failure();
280 
281  auto structType = llvm::dyn_cast<spirv::StructType>(resultType);
282  if (!structType || structType.getNumElements() != 2)
283  return parser.emitError(loc, "expected spirv.struct type with two members");
284 
// Operand types are not spelled out in the assembly; they are implied by the
// first struct member. Full type agreement is presumably enforced by
// verifyArithmeticExtendedBinaryOp.
285  SmallVector<Type, 2> operandTypes(2, structType.getElementType(0));
286  if (parser.resolveOperands(operands, operandTypes, loc, result.operands))
287  return failure();
288 
289  result.addTypes(resultType);
290  return success();
291 }
292 
294  OpAsmPrinter &printer) {
// Prints the optional attribute dictionary, the operands and
// ` : <first result type>`, the same order that
// parseArithmeticExtendedBinaryOp reads them in.
295  printer << ' ';
296  printer.printOptionalAttrDict(op->getAttrs());
297  printer.printOperands(op->getOperands());
298  printer << " : " << op->getResultTypes().front();
299 }
300 
301 static LogicalResult verifyShiftOp(Operation *op) {
302  if (op->getOperand(0).getType() != op->getResult(0).getType()) {
303  return op->emitError("expected the same type for the first operand and "
304  "result, but provided ")
305  << op->getOperand(0).getType() << " and "
306  << op->getResult(0).getType();
307  }
308  return success();
309 }
310 
311 //===----------------------------------------------------------------------===//
312 // spirv.mlir.addressof
313 //===----------------------------------------------------------------------===//
314 
315 void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state,
316  spirv::GlobalVariableOp var) {
317  build(builder, state, var.getType(), SymbolRefAttr::get(var));
318 }
319 
320 LogicalResult spirv::AddressOfOp::verify() {
321  auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
322  SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(),
323  getVariableAttr()));
324  if (!varOp) {
325  return emitOpError("expected spirv.GlobalVariable symbol");
326  }
327  if (getPointer().getType() != varOp.getType()) {
328  return emitOpError(
329  "result type mismatch with the referenced global variable's type");
330  }
331  return success();
332 }
333 
334 //===----------------------------------------------------------------------===//
335 // spirv.CompositeConstruct
336 //===----------------------------------------------------------------------===//
337 
// Verifies spirv.CompositeConstruct: the constituents must match the result
// type according to one of the four cases enumerated below.
338 LogicalResult spirv::CompositeConstructOp::verify() {
339  operand_range constituents = this->getConstituents();
340 
341  // There are 4 cases with varying verification rules:
342  // 1. Cooperative Matrices (1 constituent)
343  // 2. Structs (1 constituent for each member)
344  // 3. Arrays (1 constituent for each array element)
345  // 4. Vectors (1 constituent (sub-)element for each vector element)
346 
// Element type of a cooperative-matrix result type, or null when the result
// type matches none of the type-switch cases (see `.Default` below).
347  auto coopElementType =
350  [](auto coopType) { return coopType.getElementType(); })
351  .Default([](Type) { return nullptr; });
352 
353  // Case 1. -- cooperative matrices.
354  if (coopElementType) {
355  if (constituents.size() != 1)
356  return emitOpError("has incorrect number of operands: expected ")
357  << "1, but provided " << constituents.size();
358  if (coopElementType != constituents.front().getType())
359  return emitOpError("operand type mismatch: expected operand type ")
360  << coopElementType << ", but provided "
361  << constituents.front().getType();
362  return success();
363  }
364 
365  // Case 2./3./4. -- number of constituents matches the number of elements.
366  auto cType = llvm::cast<spirv::CompositeType>(getType());
367  if (constituents.size() == cType.getNumElements()) {
368  for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
369  if (constituents[index].getType() != cType.getElementType(index)) {
370  return emitOpError("operand type mismatch: expected operand type ")
371  << cType.getElementType(index) << ", but provided "
372  << constituents[index].getType();
373  }
374  }
375  return success();
376  }
377 
378  // Case 4. -- check that all constituents add up to the expected vector type.
379  auto resultType = llvm::dyn_cast<VectorType>(cType);
380  if (!resultType)
381  return emitOpError(
382  "expected to return a vector or cooperative matrix when the number of "
383  "constituents is less than what the result needs");
384 
// Collect how many result elements each constituent contributes: a scalar
// contributes 1, a vector its element count.
385  SmallVector<unsigned> sizes;
386  for (Value component : constituents) {
387  if (!llvm::isa<VectorType>(component.getType()) &&
388  !component.getType().isIntOrFloat())
389  return emitOpError("operand type mismatch: expected operand to have "
390  "a scalar or vector type, but provided ")
391  << component.getType();
392 
393  Type elementType = component.getType();
394  if (auto vectorType = llvm::dyn_cast<VectorType>(component.getType())) {
395  sizes.push_back(vectorType.getNumElements());
396  elementType = vectorType.getElementType();
397  } else {
398  sizes.push_back(1);
399  }
400 
401  if (elementType != resultType.getElementType())
402  return emitOpError("operand element type mismatch: expected to be ")
403  << resultType.getElementType() << ", but provided " << elementType;
404  }
405  unsigned totalCount = std::accumulate(sizes.begin(), sizes.end(), 0);
406  if (totalCount != cType.getNumElements())
407  return emitOpError("has incorrect number of operands: expected ")
408  << cType.getNumElements() << ", but provided " << totalCount;
409  return success();
410 }
411 
412 //===----------------------------------------------------------------------===//
413 // spirv.CompositeExtractOp
414 //===----------------------------------------------------------------------===//
415 
416 void spirv::CompositeExtractOp::build(OpBuilder &builder, OperationState &state,
417  Value composite,
418  ArrayRef<int32_t> indices) {
419  auto indexAttr = builder.getI32ArrayAttr(indices);
420  auto elementType =
421  getElementType(composite.getType(), indexAttr, state.location);
422  if (!elementType) {
423  return;
424  }
425  build(builder, state, elementType, composite, indexAttr);
426 }
427 
429  OperationState &result) {
// Parses `<composite> <indices-attr> : <composite-type>`. The result type is
// derived by walking the composite type with the indices; errors from that
// walk are reported at the location of the indices attribute.
430  OpAsmParser::UnresolvedOperand compositeInfo;
431  Attribute indicesAttr;
432  StringRef indicesAttrName =
433  spirv::CompositeExtractOp::getIndicesAttrName(result.name);
434  Type compositeType;
435  SMLoc attrLocation;
436 
437  if (parser.parseOperand(compositeInfo) ||
438  parser.getCurrentLocation(&attrLocation) ||
439  parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
440  parser.parseColonType(compositeType) ||
441  parser.resolveOperand(compositeInfo, compositeType, result.operands)) {
442  return failure();
443  }
444 
445  Type resultType =
446  getElementType(compositeType, indicesAttr, parser, attrLocation);
447  if (!resultType) {
448  return failure();
449  }
450  result.addTypes(resultType);
451  return success();
452 }
453 
// Prints ` <composite><indices> : <composite-type>`. There is no separator
// between the operand and the index array; the result type is not printed
// because the parser re-derives it.
455  printer << ' ' << getComposite() << getIndices() << " : "
456  << getComposite().getType();
457 }
458 
459 LogicalResult spirv::CompositeExtractOp::verify() {
460  auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(getIndices());
461  auto resultType =
462  getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
463  if (!resultType)
464  return failure();
465 
466  if (resultType != getType()) {
467  return emitOpError("invalid result type: expected ")
468  << resultType << " but provided " << getType();
469  }
470 
471  return success();
472 }
473 
474 //===----------------------------------------------------------------------===//
475 // spirv.CompositeInsert
476 //===----------------------------------------------------------------------===//
477 
478 void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state,
479  Value object, Value composite,
480  ArrayRef<int32_t> indices) {
481  auto indexAttr = builder.getI32ArrayAttr(indices);
482  build(builder, state, composite.getType(), object, composite, indexAttr);
483 }
484 
// Parses `<object>, <composite> <indices-attr> : <object-type> into
// <composite-type>`. Exactly two operands are expected, and the result type
// is the composite type.
485 ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
486  OperationState &result) {
488  Type objectType, compositeType;
489  Attribute indicesAttr;
490  StringRef indicesAttrName =
491  spirv::CompositeInsertOp::getIndicesAttrName(result.name);
492  auto loc = parser.getCurrentLocation();
493 
494  return failure(
495  parser.parseOperandList(operands, 2) ||
496  parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
497  parser.parseColonType(objectType) ||
498  parser.parseKeywordType("into", compositeType) ||
499  parser.resolveOperands(operands, {objectType, compositeType}, loc,
500  result.operands) ||
501  parser.addTypesToList(compositeType, result.types));
502 }
503 
504 LogicalResult spirv::CompositeInsertOp::verify() {
505  auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(getIndices());
506  auto objectType =
507  getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
508  if (!objectType)
509  return failure();
510 
511  if (objectType != getObject().getType()) {
512  return emitOpError("object operand type should be ")
513  << objectType << ", but found " << getObject().getType();
514  }
515 
516  if (getComposite().getType() != getType()) {
517  return emitOpError("result type should be the same as "
518  "the composite type, but found ")
519  << getComposite().getType() << " vs " << getType();
520  }
521 
522  return success();
523 }
524 
// Prints ` <object>, <composite><indices> : <object-type> into
// <composite-type>`, the form accepted by CompositeInsertOp::parse.
526  printer << " " << getObject() << ", " << getComposite() << getIndices()
527  << " : " << getObject().getType() << " into "
528  << getComposite().getType();
529 }
530 
531 //===----------------------------------------------------------------------===//
532 // spirv.Constant
533 //===----------------------------------------------------------------------===//
534 
535 ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
536  OperationState &result) {
537  Attribute value;
538  StringRef valueAttrName = spirv::ConstantOp::getValueAttrName(result.name);
539  if (parser.parseAttribute(value, valueAttrName, result.attributes))
540  return failure();
541 
542  Type type = NoneType::get(parser.getContext());
543  if (auto typedAttr = llvm::dyn_cast<TypedAttr>(value))
544  type = typedAttr.getType();
545  if (llvm::isa<NoneType, TensorType>(type)) {
546  if (parser.parseColonType(type))
547  return failure();
548  }
549 
550  return parser.addTypeToList(type, result.types);
551 }
552 
553 void spirv::ConstantOp::print(OpAsmPrinter &printer) {
554  printer << ' ' << getValue();
555  if (llvm::isa<spirv::ArrayType>(getType()))
556  printer << " : " << getType();
557 }
558 
559 static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
560  Type opType) {
561  if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
562  auto valueType = llvm::cast<TypedAttr>(value).getType();
563  if (valueType != opType)
564  return op.emitOpError("result type (")
565  << opType << ") does not match value type (" << valueType << ")";
566  return success();
567  }
568  if (llvm::isa<DenseIntOrFPElementsAttr, SparseElementsAttr>(value)) {
569  auto valueType = llvm::cast<TypedAttr>(value).getType();
570  if (valueType == opType)
571  return success();
572  auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
573  auto shapedType = llvm::dyn_cast<ShapedType>(valueType);
574  if (!arrayType)
575  return op.emitOpError("result or element type (")
576  << opType << ") does not match value type (" << valueType
577  << "), must be the same or spirv.array";
578 
579  int numElements = arrayType.getNumElements();
580  auto opElemType = arrayType.getElementType();
581  while (auto t = llvm::dyn_cast<spirv::ArrayType>(opElemType)) {
582  numElements *= t.getNumElements();
583  opElemType = t.getElementType();
584  }
585  if (!opElemType.isIntOrFloat())
586  return op.emitOpError("only support nested array result type");
587 
588  auto valueElemType = shapedType.getElementType();
589  if (valueElemType != opElemType) {
590  return op.emitOpError("result element type (")
591  << opElemType << ") does not match value element type ("
592  << valueElemType << ")";
593  }
594 
595  if (numElements != shapedType.getNumElements()) {
596  return op.emitOpError("result number of elements (")
597  << numElements << ") does not match value number of elements ("
598  << shapedType.getNumElements() << ")";
599  }
600  return success();
601  }
602  if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(value)) {
603  auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
604  if (!arrayType)
605  return op.emitOpError(
606  "must have spirv.array result type for array value");
607  Type elemType = arrayType.getElementType();
608  for (Attribute element : arrayAttr.getValue()) {
609  // Verify array elements recursively.
610  if (failed(verifyConstantType(op, element, elemType)))
611  return failure();
612  }
613  return success();
614  }
615  return op.emitOpError("cannot have attribute: ") << value;
616 }
617 
618 LogicalResult spirv::ConstantOp::verify() {
619  // ODS already generates checks to make sure the result type is valid. We just
620  // need to additionally check that the value's attribute type is consistent
621  // with the result type.
622  return verifyConstantType(*this, getValueAttr(), getType());
623 }
624 
625 bool spirv::ConstantOp::isBuildableWith(Type type) {
626  // Must be valid SPIR-V type first.
627  if (!llvm::isa<spirv::SPIRVType>(type))
628  return false;
629 
630  if (isa<SPIRVDialect>(type.getDialect())) {
631  // TODO: support constant struct
632  return llvm::isa<spirv::ArrayType>(type);
633  }
634 
635  return true;
636 }
637 
638 spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
639  OpBuilder &builder) {
640  if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
641  unsigned width = intType.getWidth();
642  if (width == 1)
643  return builder.create<spirv::ConstantOp>(loc, type,
644  builder.getBoolAttr(false));
645  return builder.create<spirv::ConstantOp>(
646  loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
647  }
648  if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
649  return builder.create<spirv::ConstantOp>(
650  loc, type, builder.getFloatAttr(floatType, 0.0));
651  }
652  if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
653  Type elemType = vectorType.getElementType();
654  if (llvm::isa<IntegerType>(elemType)) {
655  return builder.create<spirv::ConstantOp>(
656  loc, type,
657  DenseElementsAttr::get(vectorType,
658  IntegerAttr::get(elemType, 0).getValue()));
659  }
660  if (llvm::isa<FloatType>(elemType)) {
661  return builder.create<spirv::ConstantOp>(
662  loc, type,
663  DenseFPElementsAttr::get(vectorType,
664  FloatAttr::get(elemType, 0.0).getValue()));
665  }
666  }
667 
668  llvm_unreachable("unimplemented types for ConstantOp::getZero()");
669 }
670 
671 spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
672  OpBuilder &builder) {
673  if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
674  unsigned width = intType.getWidth();
675  if (width == 1)
676  return builder.create<spirv::ConstantOp>(loc, type,
677  builder.getBoolAttr(true));
678  return builder.create<spirv::ConstantOp>(
679  loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
680  }
681  if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
682  return builder.create<spirv::ConstantOp>(
683  loc, type, builder.getFloatAttr(floatType, 1.0));
684  }
685  if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
686  Type elemType = vectorType.getElementType();
687  if (llvm::isa<IntegerType>(elemType)) {
688  return builder.create<spirv::ConstantOp>(
689  loc, type,
690  DenseElementsAttr::get(vectorType,
691  IntegerAttr::get(elemType, 1).getValue()));
692  }
693  if (llvm::isa<FloatType>(elemType)) {
694  return builder.create<spirv::ConstantOp>(
695  loc, type,
696  DenseFPElementsAttr::get(vectorType,
697  FloatAttr::get(elemType, 1.0).getValue()));
698  }
699  }
700 
701  llvm_unreachable("unimplemented types for ConstantOp::getOne()");
702 }
703 
704 void mlir::spirv::ConstantOp::getAsmResultNames(
705  llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
706  Type type = getType();
707 
708  SmallString<32> specialNameBuffer;
709  llvm::raw_svector_ostream specialName(specialNameBuffer);
710  specialName << "cst";
711 
712  IntegerType intTy = llvm::dyn_cast<IntegerType>(type);
713 
714  if (IntegerAttr intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
715  if (intTy && intTy.getWidth() == 1) {
716  return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
717  }
718 
719  if (intTy.isSignless()) {
720  specialName << intCst.getInt();
721  } else if (intTy.isUnsigned()) {
722  specialName << intCst.getUInt();
723  } else {
724  specialName << intCst.getSInt();
725  }
726  }
727 
728  if (intTy || llvm::isa<FloatType>(type)) {
729  specialName << '_' << type;
730  }
731 
732  if (auto vecType = llvm::dyn_cast<VectorType>(type)) {
733  specialName << "_vec_";
734  specialName << vecType.getDimSize(0);
735 
736  Type elementType = vecType.getElementType();
737 
738  if (llvm::isa<IntegerType>(elementType) ||
739  llvm::isa<FloatType>(elementType)) {
740  specialName << "x" << elementType;
741  }
742  }
743 
744  setNameFn(getResult(), specialName.str());
745 }
746 
747 void mlir::spirv::AddressOfOp::getAsmResultNames(
748  llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
749  SmallString<32> specialNameBuffer;
750  llvm::raw_svector_ostream specialName(specialNameBuffer);
751  specialName << getVariable() << "_addr";
752  setNameFn(getResult(), specialName.str());
753 }
754 
755 //===----------------------------------------------------------------------===//
756 // spirv.ControlBarrierOp
757 //===----------------------------------------------------------------------===//
758 
759 LogicalResult spirv::ControlBarrierOp::verify() {
760  return verifyMemorySemantics(getOperation(), getMemorySemantics());
761 }
762 
763 //===----------------------------------------------------------------------===//
764 // spirv.EntryPoint
765 //===----------------------------------------------------------------------===//
766 
767 void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state,
768  spirv::ExecutionModel executionModel,
769  spirv::FuncOp function,
770  ArrayRef<Attribute> interfaceVars) {
771  build(builder, state,
772  spirv::ExecutionModelAttr::get(builder.getContext(), executionModel),
773  SymbolRefAttr::get(function), builder.getArrayAttr(interfaceVars));
774 }
775 
// Parses `"<execution-model>" @fn [, @var]*`. The interface variable symbols
// are collected into the interface ArrayAttr, which is always added to the op
// (empty when no variables are listed).
776 ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
777  OperationState &result) {
778  spirv::ExecutionModel execModel;
779  SmallVector<Attribute, 4> interfaceVars;
780 
782  if (parseEnumStrAttr<spirv::ExecutionModelAttr>(execModel, parser, result) ||
783  parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) {
784  return failure();
785  }
786 
787  if (!parser.parseOptionalComma()) {
788  // Parse the interface variables
789  if (parser.parseCommaSeparatedList([&]() -> ParseResult {
790  // The name of the interface variable attribute isn't important:
  // `attrs` is a scratch list and only `var` itself is kept.
791  FlatSymbolRefAttr var;
792  NamedAttrList attrs;
793  if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
794  return failure();
795  interfaceVars.push_back(var);
796  return success();
797  }))
798  return failure();
799  }
800  result.addAttribute(spirv::EntryPointOp::getInterfaceAttrName(result.name),
801  parser.getBuilder().getArrayAttr(interfaceVars));
802  return success();
803 }
804 
// Prints ` "<execution-model>" @fn`, followed by `, <interface vars>` when
// the interface list is non-empty.
806  printer << " \"" << stringifyExecutionModel(getExecutionModel()) << "\" ";
807  printer.printSymbolName(getFn());
808  auto interfaceVars = getInterface().getValue();
809  if (!interfaceVars.empty())
810  printer << ", " << llvm::interleaved(interfaceVars);
811 }
812 
813 LogicalResult spirv::EntryPointOp::verify() {
814  // Checks for fn and interface symbol reference are done in spirv::ModuleOp
815  // verification.
816  return success();
817 }
818 
819 //===----------------------------------------------------------------------===//
820 // spirv.ExecutionMode
821 //===----------------------------------------------------------------------===//
822 
823 void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state,
824  spirv::FuncOp function,
825  spirv::ExecutionMode executionMode,
826  ArrayRef<int32_t> params) {
827  build(builder, state, SymbolRefAttr::get(function),
828  spirv::ExecutionModeAttr::get(builder.getContext(), executionMode),
829  builder.getI32ArrayAttr(params));
830 }
831 
// Parses `@fn "<execution-mode>" [, <i32 literal>]*`. The trailing integers
// are stored as an i32 array attribute under the op's values attribute name
// (empty when none are given).
832 ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
833  OperationState &result) {
834  spirv::ExecutionMode execMode;
835  Attribute fn;
836  if (parser.parseAttribute(fn, kFnNameAttrName, result.attributes) ||
837  parseEnumStrAttr<spirv::ExecutionModeAttr>(execMode, parser, result)) {
838  return failure();
839  }
840 
842  Type i32Type = parser.getBuilder().getIntegerType(32);
843  while (!parser.parseOptionalComma()) {
// Each value is parsed as an i32-typed attribute into a scratch list; only
// its integer value is kept.
844  NamedAttrList attr;
845  Attribute value;
846  if (parser.parseAttribute(value, i32Type, "value", attr)) {
847  return failure();
848  }
// NOTE(review): llvm::cast asserts if the parsed attribute is not an
// IntegerAttr; confirm whether parseAttribute can return another kind here
// (e.g. for a string literal), in which case a diagnostic would be better.
849  values.push_back(llvm::cast<IntegerAttr>(value).getInt());
850  }
851  StringRef valuesAttrName =
852  spirv::ExecutionModeOp::getValuesAttrName(result.name);
853  result.addAttribute(valuesAttrName,
854  parser.getBuilder().getI32ArrayAttr(values));
855  return success();
856 }
857 
// Prints ` @fn "<execution-mode>"`, followed by `, v0, v1, ...` when there
// are literal values.
859  printer << " ";
860  printer.printSymbolName(getFn());
861  printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\"";
862  ArrayAttr values = this->getValues();
863  if (!values.empty())
864  printer << ", " << llvm::interleaved(values.getAsValueRange<IntegerAttr>());
865 }
866 
867 //===----------------------------------------------------------------------===//
868 // spirv.func
869 //===----------------------------------------------------------------------===//
870 
// Parses `@name(<args>) -> <results> "<function-control>"
// [attributes {...}] [<region>]`. Variadic signatures are not allowed. The
// function type is built from the parsed argument and result types,
// argument/result attributes are attached to the op, and an absent region
// leaves the body empty (an external function).
871 ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
873  SmallVector<DictionaryAttr> resultAttrs;
874  SmallVector<Type> resultTypes;
875  auto &builder = parser.getBuilder();
876 
877  // Parse the name as a symbol.
878  StringAttr nameAttr;
879  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
880  result.attributes))
881  return failure();
882 
883  // Parse the function signature.
884  bool isVariadic = false;
886  parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
887  resultAttrs))
888  return failure();
889 
890  SmallVector<Type> argTypes;
891  for (auto &arg : entryArgs)
892  argTypes.push_back(arg.type);
893  auto fnType = builder.getFunctionType(argTypes, resultTypes);
894  result.addAttribute(getFunctionTypeAttrName(result.name),
895  TypeAttr::get(fnType));
896 
897  // Parse the optional function control keyword.
898  spirv::FunctionControl fnControl;
899  if (parseEnumStrAttr<spirv::FunctionControlAttr>(fnControl, parser, result))
900  return failure();
901 
902  // If additional attributes are present, parse them.
903  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
904  return failure();
905 
906  // Attach the parsed argument and result attributes to the op.
907  assert(resultAttrs.size() == resultTypes.size());
909  builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
910  getResAttrsAttrName(result.name));
911 
912  // Parse the optional function body.
913  auto *body = result.addRegion();
914  OptionalParseResult parseResult =
915  parser.parseOptionalRegion(*body, entryArgs);
916  return failure(parseResult.has_value() && failed(*parseResult));
917 }
918 
// Prints the custom form of spirv.func, in this order:
//   1. ` @name`;
//   2. the signature (non-variadic);
//   3. the quoted function-control string;
//   4. the remaining attributes, eliding the function-control, function-type
//      and arg/result attribute names;
//   5. the body region, if non-empty.
// NOTE(review): two call lines are missing from this listing (the signature
// printer and the attribute printer whose argument lists are visible below).
919 void spirv::FuncOp::print(OpAsmPrinter &printer) {
920  // Print function name, signature, and control.
921  printer << " ";
922  printer.printSymbolName(getSymName());
923  auto fnType = getFunctionType();
925  printer, *this, fnType.getInputs(),
926  /*isVariadic=*/false, fnType.getResults());
927  printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl())
928  << "\"";
930  printer, *this,
931  {spirv::attributeName<spirv::FunctionControl>(),
932  getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
933  getFunctionControlAttrName()});
934 
935  // Print the body if this is not an external function.
936  Region &body = this->getBody();
937  if (!body.empty()) {
938  printer << ' ';
// Entry block arguments are not printed here; presumably they are already
// shown in the signature above. Terminators are always printed.
939  printer.printRegion(body, /*printEntryBlockArgs=*/false,
940  /*printBlockTerminators=*/true);
941  }
942 }
943 
944 LogicalResult spirv::FuncOp::verifyType() {
945  FunctionType fnType = getFunctionType();
946  if (fnType.getNumResults() > 1)
947  return emitOpError("cannot have more than one result");
948 
949  auto hasDecorationAttr = [&](spirv::Decoration decoration,
950  unsigned argIndex) {
951  auto func = llvm::cast<FunctionOpInterface>(getOperation());
952  for (auto argAttr : cast<FunctionOpInterface>(func).getArgAttrs(argIndex)) {
953  if (argAttr.getName() != spirv::DecorationAttr::name)
954  continue;
955  if (auto decAttr = dyn_cast<spirv::DecorationAttr>(argAttr.getValue()))
956  return decAttr.getValue() == decoration;
957  }
958  return false;
959  };
960 
961  for (unsigned i = 0, e = this->getNumArguments(); i != e; ++i) {
962  Type param = fnType.getInputs()[i];
963  auto inputPtrType = dyn_cast<spirv::PointerType>(param);
964  if (!inputPtrType)
965  continue;
966 
967  auto pointeePtrType =
968  dyn_cast<spirv::PointerType>(inputPtrType.getPointeeType());
969  if (pointeePtrType) {
970  // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
971  // > If an OpFunctionParameter is a pointer (or contains a pointer)
972  // > and the type it points to is a pointer in the PhysicalStorageBuffer
973  // > storage class, the function parameter must be decorated with exactly
974  // > one of AliasedPointer or RestrictPointer.
975  if (pointeePtrType.getStorageClass() !=
976  spirv::StorageClass::PhysicalStorageBuffer)
977  continue;
978 
979  bool hasAliasedPtr =
980  hasDecorationAttr(spirv::Decoration::AliasedPointer, i);
981  bool hasRestrictPtr =
982  hasDecorationAttr(spirv::Decoration::RestrictPointer, i);
983  if (!hasAliasedPtr && !hasRestrictPtr)
984  return emitOpError()
985  << "with a pointer points to a physical buffer pointer must "
986  "be decorated either 'AliasedPointer' or 'RestrictPointer'";
987  continue;
988  }
989  // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
990  // > If an OpFunctionParameter is a pointer (or contains a pointer) in
991  // > the PhysicalStorageBuffer storage class, the function parameter must
992  // > be decorated with exactly one of Aliased or Restrict.
993  if (auto pointeeArrayType =
994  dyn_cast<spirv::ArrayType>(inputPtrType.getPointeeType())) {
995  pointeePtrType =
996  dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
997  } else {
998  pointeePtrType = inputPtrType;
999  }
1000 
1001  if (!pointeePtrType || pointeePtrType.getStorageClass() !=
1002  spirv::StorageClass::PhysicalStorageBuffer)
1003  continue;
1004 
1005  bool hasAliased = hasDecorationAttr(spirv::Decoration::Aliased, i);
1006  bool hasRestrict = hasDecorationAttr(spirv::Decoration::Restrict, i);
1007  if (!hasAliased && !hasRestrict)
1008  return emitOpError() << "with physical buffer pointer must be decorated "
1009  "either 'Aliased' or 'Restrict'";
1010  }
1011 
1012  return success();
1013 }
1014 
1015 LogicalResult spirv::FuncOp::verifyBody() {
1016  FunctionType fnType = getFunctionType();
1017  if (!isExternal()) {
1018  Block &entryBlock = front();
1019 
1020  unsigned numArguments = this->getNumArguments();
1021  if (entryBlock.getNumArguments() != numArguments)
1022  return emitOpError("entry block must have ")
1023  << numArguments << " arguments to match function signature";
1024 
1025  for (auto [index, fnArgType, blockArgType] :
1026  llvm::enumerate(getArgumentTypes(), entryBlock.getArgumentTypes())) {
1027  if (blockArgType != fnArgType) {
1028  return emitOpError("type of entry block argument #")
1029  << index << '(' << blockArgType
1030  << ") must match the type of the corresponding argument in "
1031  << "function signature(" << fnArgType << ')';
1032  }
1033  }
1034  }
1035 
1036  auto walkResult = walk([fnType](Operation *op) -> WalkResult {
1037  if (auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
1038  if (fnType.getNumResults() != 0)
1039  return retOp.emitOpError("cannot be used in functions returning value");
1040  } else if (auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
1041  if (fnType.getNumResults() != 1)
1042  return retOp.emitOpError(
1043  "returns 1 value but enclosing function requires ")
1044  << fnType.getNumResults() << " results";
1045 
1046  auto retOperandType = retOp.getValue().getType();
1047  auto fnResultType = fnType.getResult(0);
1048  if (retOperandType != fnResultType)
1049  return retOp.emitOpError(" return value's type (")
1050  << retOperandType << ") mismatch with function's result type ("
1051  << fnResultType << ")";
1052  }
1053  return WalkResult::advance();
1054  });
1055 
1056  // TODO: verify other bits like linkage type.
1057 
1058  return failure(walkResult.wasInterrupted());
1059 }
1060 
1061 void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
1062  StringRef name, FunctionType type,
1063  spirv::FunctionControl control,
1064  ArrayRef<NamedAttribute> attrs) {
1065  state.addAttribute(SymbolTable::getSymbolAttrName(),
1066  builder.getStringAttr(name));
1067  state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
1068  state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
1069  builder.getAttr<spirv::FunctionControlAttr>(control));
1070  state.attributes.append(attrs.begin(), attrs.end());
1071  state.addRegion();
1072 }
1073 
1074 //===----------------------------------------------------------------------===//
1075 // spirv.GLFClampOp
1076 //===----------------------------------------------------------------------===//
1077 
1078 ParseResult spirv::GLFClampOp::parse(OpAsmParser &parser,
1079  OperationState &result) {
1080  return parseOneResultSameOperandTypeOp(parser, result);
1081 }
1083 
1084 //===----------------------------------------------------------------------===//
1085 // spirv.GLUClampOp
1086 //===----------------------------------------------------------------------===//
1087 
1088 ParseResult spirv::GLUClampOp::parse(OpAsmParser &parser,
1089  OperationState &result) {
1090  return parseOneResultSameOperandTypeOp(parser, result);
1091 }
1093 
1094 //===----------------------------------------------------------------------===//
1095 // spirv.GLSClampOp
1096 //===----------------------------------------------------------------------===//
1097 
1098 ParseResult spirv::GLSClampOp::parse(OpAsmParser &parser,
1099  OperationState &result) {
1100  return parseOneResultSameOperandTypeOp(parser, result);
1101 }
1103 
1104 //===----------------------------------------------------------------------===//
1105 // spirv.GLFmaOp
1106 //===----------------------------------------------------------------------===//
1107 
1108 ParseResult spirv::GLFmaOp::parse(OpAsmParser &parser, OperationState &result) {
1109  return parseOneResultSameOperandTypeOp(parser, result);
1110 }
1111 void spirv::GLFmaOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1112 
1113 //===----------------------------------------------------------------------===//
1114 // spirv.GlobalVariable
1115 //===----------------------------------------------------------------------===//
1116 
1117 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
1118  Type type, StringRef name,
1119  unsigned descriptorSet, unsigned binding) {
1120  build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
1121  state.addAttribute(
1122  spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
1123  builder.getI32IntegerAttr(descriptorSet));
1124  state.addAttribute(
1125  spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
1126  builder.getI32IntegerAttr(binding));
1127 }
1128 
1129 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
1130  Type type, StringRef name,
1131  spirv::BuiltIn builtin) {
1132  build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
1133  state.addAttribute(
1134  spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
1135  builder.getStringAttr(spirv::stringifyBuiltIn(builtin)));
1136 }
1137 
1138 ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
1139  OperationState &result) {
1140  // Parse variable name.
1141  StringAttr nameAttr;
1142  StringRef initializerAttrName =
1143  spirv::GlobalVariableOp::getInitializerAttrName(result.name);
1144  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1145  result.attributes)) {
1146  return failure();
1147  }
1148 
1149  // Parse optional initializer
1150  if (succeeded(parser.parseOptionalKeyword(initializerAttrName))) {
1151  FlatSymbolRefAttr initSymbol;
1152  if (parser.parseLParen() ||
1153  parser.parseAttribute(initSymbol, Type(), initializerAttrName,
1154  result.attributes) ||
1155  parser.parseRParen())
1156  return failure();
1157  }
1158 
1159  if (parseVariableDecorations(parser, result)) {
1160  return failure();
1161  }
1162 
1163  Type type;
1164  StringRef typeAttrName =
1165  spirv::GlobalVariableOp::getTypeAttrName(result.name);
1166  auto loc = parser.getCurrentLocation();
1167  if (parser.parseColonType(type)) {
1168  return failure();
1169  }
1170  if (!llvm::isa<spirv::PointerType>(type)) {
1171  return parser.emitError(loc, "expected spirv.ptr type");
1172  }
1173  result.addAttribute(typeAttrName, TypeAttr::get(type));
1174 
1175  return success();
1176 }
1177 
// Prints spirv.GlobalVariable in this order:
//   1. ` @name`;
//   2. if an initializer is set, ` <initializer-attr-name>(@sym)`;
//   3. the variable decorations, via printVariableDecorations;
//   4. ` : <type>`.
// printVariableDecorations receives `elidedAttrs`, which holds the names of
// everything already printed plus the storage-class attribute. The
// storage-class attribute is always elided, presumably because it is implied
// by the pointer type; confirm.
// NOTE(review): the signature line of this definition is missing from this
// listing.
1179  SmallVector<StringRef, 4> elidedAttrs{
1180  spirv::attributeName<spirv::StorageClass>()};
1181 
1182  // Print variable name.
1183  printer << ' ';
1184  printer.printSymbolName(getSymName());
1185  elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
1186 
1187  StringRef initializerAttrName = this->getInitializerAttrName();
1188  // Print optional initializer
1189  if (auto initializer = this->getInitializer()) {
1190  printer << " " << initializerAttrName << '(';
1191  printer.printSymbolName(*initializer);
1192  printer << ')';
1193  elidedAttrs.push_back(initializerAttrName);
1194  }
1195 
1196  StringRef typeAttrName = this->getTypeAttrName();
1197  elidedAttrs.push_back(typeAttrName);
1198  spirv::printVariableDecorations(*this, printer, elidedAttrs);
1199  printer << " : " << getType();
1200 }
1201 
// Verifies spirv.GlobalVariable:
//  * the type must be a spirv pointer type;
//  * the storage class may be neither Generic nor Function;
//  * if an initializer symbol is set, it must resolve to a
//    spirv.GlobalVariable, spirv.SpecConstant or spirv.SpecConstantComposite
//    op.
// NOTE(review): the line declaring `initOp` (a symbol lookup starting at the
// parent op) is missing from this listing.
// NOTE(review): the first diagnostic spells the type `!spv.ptr`, while the
// parser's diagnostic uses `spirv.ptr`. Both are runtime strings and are left
// unchanged here.
1202 LogicalResult spirv::GlobalVariableOp::verify() {
1203  if (!llvm::isa<spirv::PointerType>(getType()))
1204  return emitOpError("result must be of a !spv.ptr type");
1205 
1206  // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
1207  // object. It cannot be Generic. It must be the same as the Storage Class
1208  // operand of the Result Type."
1209  // Also, Function storage class is reserved by spirv.Variable.
1210  auto storageClass = this->storageClass();
1211  if (storageClass == spirv::StorageClass::Generic ||
1212  storageClass == spirv::StorageClass::Function) {
1213  return emitOpError("storage class cannot be '")
1214  << stringifyStorageClass(storageClass) << "'";
1215  }
1216 
1217  if (auto init = (*this)->getAttrOfType<FlatSymbolRefAttr>(
1218  this->getInitializerAttrName())) {
1220  (*this)->getParentOp(), init.getAttr());
1221  // TODO: Currently only variable initialization with specialization
1222  // constants and other variables is supported. They could be normal
1223  // constants in the module scope as well.
1224  if (!initOp || !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp,
1225  spirv::SpecConstantCompositeOp>(initOp)) {
1226  return emitOpError("initializer must be result of a "
1227  "spirv.SpecConstant or spirv.GlobalVariable or "
1228  "spirv.SpecConstantCompositeOp op");
1229  }
1230  }
1231 
1232  return success();
1233 }
1234 
1235 //===----------------------------------------------------------------------===//
1236 // spirv.INTEL.SubgroupBlockRead
1237 //===----------------------------------------------------------------------===//
1238 
// Verifier for spirv.INTEL.SubgroupBlockRead. It delegates the check of the
// pointer and value types to verifyBlockReadWritePtrAndValTypes.
// NOTE(review): the signature line of this definition is missing from this
// listing.
1240  if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
1241  return failure();
1242 
1243  return success();
1244 }
1245 
1246 //===----------------------------------------------------------------------===//
1247 // spirv.INTEL.SubgroupBlockWrite
1248 //===----------------------------------------------------------------------===//
1249 
// Parser for spirv.INTEL.SubgroupBlockWrite. It reads, in order:
//   1. a storage-class string;
//   2. two operands (pointer, then value);
//   3. `:` and the value type.
// The pointer operand's type is rebuilt as a spirv pointer in that storage
// class. Its pointee is the value type, or the element type when the value is
// a vector.
// NOTE(review): the signature line and the declaration of `operandInfo` are
// missing from this listing.
// NOTE(review): print() below does not emit the storage class, yet this
// parser reads it first. Printed output therefore appears not to round-trip;
// confirm.
1251  OperationState &result) {
1252  // Parse the storage class specification
1253  spirv::StorageClass storageClass;
1255  auto loc = parser.getCurrentLocation();
1256  Type elementType;
1257  if (parseEnumStrAttr(storageClass, parser) ||
1258  parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
1259  parser.parseType(elementType)) {
1260  return failure();
1261  }
1262 
1263  auto ptrType = spirv::PointerType::get(elementType, storageClass);
1264  if (auto valVecTy = llvm::dyn_cast<VectorType>(elementType))
1265  ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
1266 
1267  if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
1268  result.operands)) {
1269  return failure();
1270  }
1271  return success();
1272 }
1273 
// Printer for spirv.INTEL.SubgroupBlockWrite: ` %ptr, %value : <value type>`.
// NOTE(review): the signature line of this definition is missing from this
// listing.
1275  printer << " " << getPtr() << ", " << getValue() << " : "
1276  << getValue().getType();
1277 }
1278 
// Verifier for spirv.INTEL.SubgroupBlockWrite. It uses the same
// pointer/value type check as the block-read verifier.
// NOTE(review): the signature line of this definition is missing from this
// listing.
1280  if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
1281  return failure();
1282 
1283  return success();
1284 }
1285 
1286 //===----------------------------------------------------------------------===//
1287 // spirv.IAddCarryOp
1288 //===----------------------------------------------------------------------===//
1289 
// Verifier and parser for spirv.IAddCarry.
// NOTE(review): the single body line of each definition below is missing
// from this listing. They presumably delegate to shared extended-arithmetic
// helpers, as print() does; confirm against the source.
1290 LogicalResult spirv::IAddCarryOp::verify() {
1292 }
1293 
1294 ParseResult spirv::IAddCarryOp::parse(OpAsmParser &parser,
1295  OperationState &result) {
1297 }
1298 
1299 void spirv::IAddCarryOp::print(OpAsmPrinter &printer) {
1300  ::printArithmeticExtendedBinaryOp(*this, printer);
1301 }
1302 
1303 //===----------------------------------------------------------------------===//
1304 // spirv.ISubBorrowOp
1305 //===----------------------------------------------------------------------===//
1306 
// Verifier, parser and printer for spirv.ISubBorrow. The printer uses the
// shared extended-arithmetic binary-op printer.
// NOTE(review): the body lines of verify() and parse(), and the signature
// line of print(), are missing from this listing.
1307 LogicalResult spirv::ISubBorrowOp::verify() {
1309 }
1310 
1311 ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser,
1312  OperationState &result) {
1314 }
1315 
1317  ::printArithmeticExtendedBinaryOp(*this, printer);
1318 }
1319 
1320 //===----------------------------------------------------------------------===//
1321 // spirv.SMulExtended
1322 //===----------------------------------------------------------------------===//
1323 
// Verifier, parser and printer for spirv.SMulExtended. The printer uses the
// shared extended-arithmetic binary-op printer.
// NOTE(review): the body lines of verify() and parse(), and the signature
// line of print(), are missing from this listing.
1324 LogicalResult spirv::SMulExtendedOp::verify() {
1326 }
1327 
1328 ParseResult spirv::SMulExtendedOp::parse(OpAsmParser &parser,
1329  OperationState &result) {
1331 }
1332 
1334  ::printArithmeticExtendedBinaryOp(*this, printer);
1335 }
1336 
1337 //===----------------------------------------------------------------------===//
1338 // spirv.UMulExtended
1339 //===----------------------------------------------------------------------===//
1340 
// Verifier, parser and printer for spirv.UMulExtended. The printer uses the
// shared extended-arithmetic binary-op printer.
// NOTE(review): the body lines of verify() and parse(), and the signature
// line of print(), are missing from this listing.
1341 LogicalResult spirv::UMulExtendedOp::verify() {
1343 }
1344 
1345 ParseResult spirv::UMulExtendedOp::parse(OpAsmParser &parser,
1346  OperationState &result) {
1348 }
1349 
1351  ::printArithmeticExtendedBinaryOp(*this, printer);
1352 }
1353 
1354 //===----------------------------------------------------------------------===//
1355 // spirv.MemoryBarrierOp
1356 //===----------------------------------------------------------------------===//
1357 
1358 LogicalResult spirv::MemoryBarrierOp::verify() {
1359  return verifyMemorySemantics(getOperation(), getMemorySemantics());
1360 }
1361 
1362 //===----------------------------------------------------------------------===//
1363 // spirv.module
1364 //===----------------------------------------------------------------------===//
1365 
1366 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
1367  std::optional<StringRef> name) {
1368  OpBuilder::InsertionGuard guard(builder);
1369  builder.createBlock(state.addRegion());
1370  if (name) {
1371  state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
1372  builder.getStringAttr(*name));
1373  }
1374 }
1375 
1376 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
1377  spirv::AddressingModel addressingModel,
1378  spirv::MemoryModel memoryModel,
1379  std::optional<VerCapExtAttr> vceTriple,
1380  std::optional<StringRef> name) {
1381  state.addAttribute(
1382  "addressing_model",
1383  builder.getAttr<spirv::AddressingModelAttr>(addressingModel));
1384  state.addAttribute("memory_model",
1385  builder.getAttr<spirv::MemoryModelAttr>(memoryModel));
1386  OpBuilder::InsertionGuard guard(builder);
1387  builder.createBlock(state.addRegion());
1388  if (vceTriple)
1389  state.addAttribute(getVCETripleAttrName(), *vceTriple);
1390  if (name)
1391  state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
1392  builder.getStringAttr(*name));
1393 }
1394 
1395 ParseResult spirv::ModuleOp::parse(OpAsmParser &parser,
1396  OperationState &result) {
1397  Region *body = result.addRegion();
1398 
1399  // If the name is present, parse it.
1400  StringAttr nameAttr;
1401  (void)parser.parseOptionalSymbolName(
1402  nameAttr, mlir::SymbolTable::getSymbolAttrName(), result.attributes);
1403 
1404  // Parse attributes
1405  spirv::AddressingModel addrModel;
1406  spirv::MemoryModel memoryModel;
1407  if (spirv::parseEnumKeywordAttr<spirv::AddressingModelAttr>(addrModel, parser,
1408  result) ||
1409  spirv::parseEnumKeywordAttr<spirv::MemoryModelAttr>(memoryModel, parser,
1410  result))
1411  return failure();
1412 
1413  if (succeeded(parser.parseOptionalKeyword("requires"))) {
1414  spirv::VerCapExtAttr vceTriple;
1415  if (parser.parseAttribute(vceTriple,
1416  spirv::ModuleOp::getVCETripleAttrName(),
1417  result.attributes))
1418  return failure();
1419  }
1420 
1421  if (parser.parseOptionalAttrDictWithKeyword(result.attributes) ||
1422  parser.parseRegion(*body, /*arguments=*/{}))
1423  return failure();
1424 
1425  // Make sure we have at least one block.
1426  if (body->empty())
1427  body->push_back(new Block());
1428 
1429  return success();
1430 }
1431 
// Prints the custom form of spirv.module, in this order:
//   1. optional ` @name`;
//   2. the addressing and memory model keywords;
//   3. if a VCE triple is set, ` requires <triple>`;
//   4. an optional `attributes {...}` dictionary of the attributes not
//      elided;
//   5. the body region.
// NOTE(review): one line is missing from this listing: the tail of the
// `elidedAttrs.assign({...})` initializer list. Confirm its contents.
1432 void spirv::ModuleOp::print(OpAsmPrinter &printer) {
1433  if (std::optional<StringRef> name = getName()) {
1434  printer << ' ';
1435  printer.printSymbolName(*name);
1436  }
1437 
1438  SmallVector<StringRef, 2> elidedAttrs;
1439 
1440  printer << " " << spirv::stringifyAddressingModel(getAddressingModel()) << " "
1441  << spirv::stringifyMemoryModel(getMemoryModel());
1442  auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
1443  auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
1444  elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
1446 
1447  if (std::optional<spirv::VerCapExtAttr> triple = getVceTriple()) {
1448  printer << " requires " << *triple;
1449  elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
1450  }
1451 
1452  printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs);
1453  printer << ' ';
1454  printer.printRegion(getRegion());
1455 }
1456 
// Verifies the contents of a spirv.module:
//  * only ops from the SPIR-V dialect may appear directly in the module;
//  * each spirv.EntryPoint must name a spirv.func in the module;
//  * each entry point's interface entries must be symbol references to
//    spirv.GlobalVariable ops;
//  * no (function, execution model) pair may be declared by two entry points;
//  * a spirv.func without a body must carry 'Import' linkage attributes;
//  * the ops directly inside each spirv.func's blocks must be SPIR-V dialect
//    ops.
// NOTE(review): the line giving the type of the `entryPoints` map is missing
// from this listing.
1457 LogicalResult spirv::ModuleOp::verifyRegions() {
1458  Dialect *dialect = (*this)->getDialect();
1460  entryPoints;
1461  mlir::SymbolTable table(*this);
1462 
1463  for (auto &op : *getBody()) {
1464  if (op.getDialect() != dialect)
1465  return op.emitError("'spirv.module' can only contain spirv.* ops");
1466 
1467  // For EntryPoint op, check that the function and execution model is not
1468  // duplicated in EntryPointOps. Also verify that the interface specified
1469  // comes from globalVariables here to make this check cheaper.
1470  if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
1471  auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.getFn());
1472  if (!funcOp) {
1473  return entryPointOp.emitError("function '")
1474  << entryPointOp.getFn() << "' not found in 'spirv.module'";
1475  }
// NOTE(review): the two interface diagnostics below have quoting defects.
// The first opens a `'` that is never closed. The second reads
// "instead of'" with no space before the quote. Both are runtime strings and
// are left unchanged here.
1476  if (auto interface = entryPointOp.getInterface()) {
1477  for (Attribute varRef : interface) {
1478  auto varSymRef = llvm::dyn_cast<FlatSymbolRefAttr>(varRef);
1479  if (!varSymRef) {
1480  return entryPointOp.emitError(
1481  "expected symbol reference for interface "
1482  "specification instead of '")
1483  << varRef;
1484  }
1485  auto variableOp =
1486  table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
1487  if (!variableOp) {
1488  return entryPointOp.emitError("expected spirv.GlobalVariable "
1489  "symbol reference instead of'")
1490  << varSymRef << "'";
1491  }
1492  }
1493  }
1494 
// try_emplace fails when this (function, execution model) pair is already
// registered by an earlier entry point.
1495  auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
1496  funcOp, entryPointOp.getExecutionModel());
1497  if (!entryPoints.try_emplace(key, entryPointOp).second)
1498  return entryPointOp.emitError("duplicate of a previous EntryPointOp");
1499  } else if (auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
1500  // If the function is external and does not have 'Import'
1501  // linkage_attributes(LinkageAttributes), throw an error. 'Import'
1502  // LinkageAttributes is used to import external functions.
1503  auto linkageAttr = funcOp.getLinkageAttributes();
1504  auto hasImportLinkage =
1505  linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
1506  spirv::LinkageType::Import);
1507  if (funcOp.isExternal() && !hasImportLinkage)
1508  return op.emitError(
1509  "'spirv.module' cannot contain external functions "
1510  "without 'Import' linkage_attributes (LinkageAttributes)");
1511 
// NOTE(review): the inner `op` shadows the outer loop variable. The loop also
// visits only ops directly inside the function's blocks, not ops inside their
// nested regions.
1512  // TODO: move this check to spirv.func.
1513  for (auto &block : funcOp)
1514  for (auto &op : block) {
1515  if (op.getDialect() != dialect)
1516  return op.emitError(
1517  "functions in 'spirv.module' can only contain spirv.* ops");
1518  }
1519  }
1520  }
1521 
1522  return success();
1523 }
1524 
1525 //===----------------------------------------------------------------------===//
1526 // spirv.mlir.referenceof
1527 //===----------------------------------------------------------------------===//
1528 
1529 LogicalResult spirv::ReferenceOfOp::verify() {
1530  auto *specConstSym = SymbolTable::lookupNearestSymbolFrom(
1531  (*this)->getParentOp(), getSpecConstAttr());
1532  Type constType;
1533 
1534  auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
1535  if (specConstOp)
1536  constType = specConstOp.getDefaultValue().getType();
1537 
1538  auto specConstCompositeOp =
1539  dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
1540  if (specConstCompositeOp)
1541  constType = specConstCompositeOp.getType();
1542 
1543  if (!specConstOp && !specConstCompositeOp)
1544  return emitOpError(
1545  "expected spirv.SpecConstant or spirv.SpecConstantComposite symbol");
1546 
1547  if (getReference().getType() != constType)
1548  return emitOpError("result type mismatch with the referenced "
1549  "specialization constant's type");
1550 
1551  return success();
1552 }
1553 
1554 //===----------------------------------------------------------------------===//
1555 // spirv.SpecConstant
1556 //===----------------------------------------------------------------------===//
1557 
1558 ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
1559  OperationState &result) {
1560  StringAttr nameAttr;
1561  Attribute valueAttr;
1562  StringRef defaultValueAttrName =
1563  spirv::SpecConstantOp::getDefaultValueAttrName(result.name);
1564 
1565  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1566  result.attributes))
1567  return failure();
1568 
1569  // Parse optional spec_id.
1570  if (succeeded(parser.parseOptionalKeyword(kSpecIdAttrName))) {
1571  IntegerAttr specIdAttr;
1572  if (parser.parseLParen() ||
1573  parser.parseAttribute(specIdAttr, kSpecIdAttrName, result.attributes) ||
1574  parser.parseRParen())
1575  return failure();
1576  }
1577 
1578  if (parser.parseEqual() ||
1579  parser.parseAttribute(valueAttr, defaultValueAttrName, result.attributes))
1580  return failure();
1581 
1582  return success();
1583 }
1584 
// Prints spirv.SpecConstant as ` @name`, then ` <kSpecIdAttrName>(<n>)` when
// a spec id attribute is present, then ` = <default value>`.
// NOTE(review): the signature line of this definition is missing from this
// listing.
1586  printer << ' ';
1587  printer.printSymbolName(getSymName());
1588  if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
1589  printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')';
1590  printer << " = " << getDefaultValue();
1591 }
1592 
1593 LogicalResult spirv::SpecConstantOp::verify() {
1594  if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
1595  if (specID.getValue().isNegative())
1596  return emitOpError("SpecId cannot be negative");
1597 
1598  auto value = getDefaultValue();
1599  if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
1600  // Make sure bitwidth is allowed.
1601  if (!llvm::isa<spirv::SPIRVType>(value.getType()))
1602  return emitOpError("default value bitwidth disallowed");
1603  return success();
1604  }
1605  return emitOpError(
1606  "default value can only be a bool, integer, or float scalar");
1607 }
1608 
1609 //===----------------------------------------------------------------------===//
1610 // spirv.VectorShuffle
1611 //===----------------------------------------------------------------------===//
1612 
1613 LogicalResult spirv::VectorShuffleOp::verify() {
1614  VectorType resultType = llvm::cast<VectorType>(getType());
1615 
1616  size_t numResultElements = resultType.getNumElements();
1617  if (numResultElements != getComponents().size())
1618  return emitOpError("result type element count (")
1619  << numResultElements
1620  << ") mismatch with the number of component selectors ("
1621  << getComponents().size() << ")";
1622 
1623  size_t totalSrcElements =
1624  llvm::cast<VectorType>(getVector1().getType()).getNumElements() +
1625  llvm::cast<VectorType>(getVector2().getType()).getNumElements();
1626 
1627  for (const auto &selector : getComponents().getAsValueRange<IntegerAttr>()) {
1628  uint32_t index = selector.getZExtValue();
1629  if (index >= totalSrcElements &&
1630  index != std::numeric_limits<uint32_t>().max())
1631  return emitOpError("component selector ")
1632  << index << " out of range: expected to be in [0, "
1633  << totalSrcElements << ") or 0xffffffff";
1634  }
1635  return success();
1636 }
1637 
1638 //===----------------------------------------------------------------------===//
1639 // spirv.MatrixTimesScalar
1640 //===----------------------------------------------------------------------===//
1641 
// Verifies that the scalar operand's type equals the element type of the
// matrix operand.
// NOTE(review): the `.Case<...>(` line of the TypeSwitch is missing from this
// listing, so the exact set of handled matrix types is not visible here. Any
// unhandled type yields nullptr. That trips the assert in debug builds, and
// in release builds it is reported as a type mismatch. Presumably the op's
// operand constraints make that unreachable; confirm.
1642 LogicalResult spirv::MatrixTimesScalarOp::verify() {
1643  Type elementType =
1644  llvm::TypeSwitch<Type, Type>(getMatrix().getType())
1646  [](auto matrixType) { return matrixType.getElementType(); })
1647  .Default([](Type) { return nullptr; });
1648 
1649  assert(elementType && "Unhandled type");
1650 
1651  // Check that the scalar type is the same as the matrix element type.
1652  if (getScalar().getType() != elementType)
1653  return emitOpError("input matrix components' type and scaling value must "
1654  "have the same type");
1655 
1656  return success();
1657 }
1658 
1659 //===----------------------------------------------------------------------===//
1660 // spirv.Transpose
1661 //===----------------------------------------------------------------------===//
1662 
1663 LogicalResult spirv::TransposeOp::verify() {
1664  auto inputMatrix = llvm::cast<spirv::MatrixType>(getMatrix().getType());
1665  auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().getType());
1666 
1667  // Verify that the input and output matrices have correct shapes.
1668  if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
1669  return emitError("input matrix rows count must be equal to "
1670  "output matrix columns count");
1671 
1672  if (inputMatrix.getNumColumns() != resultMatrix.getNumRows())
1673  return emitError("input matrix columns count must be equal to "
1674  "output matrix rows count");
1675 
1676  // Verify that the input and output matrices have the same component type
1677  if (inputMatrix.getElementType() != resultMatrix.getElementType())
1678  return emitError("input and output matrices must have the same "
1679  "component type");
1680 
1681  return success();
1682 }
1683 
1684 //===----------------------------------------------------------------------===//
1685 // spirv.MatrixTimesVector
1686 //===----------------------------------------------------------------------===//
1687 
1688 LogicalResult spirv::MatrixTimesVectorOp::verify() {
1689  auto matrixType = llvm::cast<spirv::MatrixType>(getMatrix().getType());
1690  auto vectorType = llvm::cast<VectorType>(getVector().getType());
1691  auto resultType = llvm::cast<VectorType>(getType());
1692 
1693  if (matrixType.getNumColumns() != vectorType.getNumElements())
1694  return emitOpError("matrix columns (")
1695  << matrixType.getNumColumns() << ") must match vector operand size ("
1696  << vectorType.getNumElements() << ")";
1697 
1698  if (resultType.getNumElements() != matrixType.getNumRows())
1699  return emitOpError("result size (")
1700  << resultType.getNumElements() << ") must match the matrix rows ("
1701  << matrixType.getNumRows() << ")";
1702 
1703  if (matrixType.getElementType() != resultType.getElementType())
1704  return emitOpError("matrix and result element types must match");
1705 
1706  return success();
1707 }
1708 
1709 //===----------------------------------------------------------------------===//
1710 // spirv.VectorTimesMatrix
1711 //===----------------------------------------------------------------------===//
1712 
1713 LogicalResult spirv::VectorTimesMatrixOp::verify() {
1714  auto vectorType = llvm::cast<VectorType>(getVector().getType());
1715  auto matrixType = llvm::cast<spirv::MatrixType>(getMatrix().getType());
1716  auto resultType = llvm::cast<VectorType>(getType());
1717 
1718  if (matrixType.getNumRows() != vectorType.getNumElements())
1719  return emitOpError("number of components in vector must equal the number "
1720  "of components in each column in matrix");
1721 
1722  if (resultType.getNumElements() != matrixType.getNumColumns())
1723  return emitOpError("number of columns in matrix must equal the number of "
1724  "components in result");
1725 
1726  if (matrixType.getElementType() != resultType.getElementType())
1727  return emitOpError("matrix must be a matrix with the same component type "
1728  "as the component type in result");
1729 
1730  return success();
1731 }
1732 
1733 //===----------------------------------------------------------------------===//
1734 // spirv.MatrixTimesMatrix
1735 //===----------------------------------------------------------------------===//
1736 
1737 LogicalResult spirv::MatrixTimesMatrixOp::verify() {
1738  auto leftMatrix = llvm::cast<spirv::MatrixType>(getLeftmatrix().getType());
1739  auto rightMatrix = llvm::cast<spirv::MatrixType>(getRightmatrix().getType());
1740  auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().getType());
1741 
1742  // left matrix columns' count and right matrix rows' count must be equal
1743  if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
1744  return emitError("left matrix columns' count must be equal to "
1745  "the right matrix rows' count");
1746 
1747  // right and result matrices columns' count must be the same
1748  if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns())
1749  return emitError(
1750  "right and result matrices must have equal columns' count");
1751 
1752  // right and result matrices component type must be the same
1753  if (rightMatrix.getElementType() != resultMatrix.getElementType())
1754  return emitError("right and result matrices' component type must"
1755  " be the same");
1756 
1757  // left and result matrices component type must be the same
1758  if (leftMatrix.getElementType() != resultMatrix.getElementType())
1759  return emitError("left and result matrices' component type"
1760  " must be the same");
1761 
1762  // left and result matrices rows count must be the same
1763  if (leftMatrix.getNumRows() != resultMatrix.getNumRows())
1764  return emitError("left and result matrices must have equal rows' count");
1765 
1766  return success();
1767 }
1768 
1769 //===----------------------------------------------------------------------===//
1770 // spirv.SpecConstantComposite
1771 //===----------------------------------------------------------------------===//
1772 
1774  OperationState &result) {
1775 
1776  StringAttr compositeName;
1777  if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
1778  result.attributes))
1779  return failure();
1780 
1781  if (parser.parseLParen())
1782  return failure();
1783 
1784  SmallVector<Attribute, 4> constituents;
1785 
1786  do {
1787  // The name of the constituent attribute isn't important
1788  const char *attrName = "spec_const";
1789  FlatSymbolRefAttr specConstRef;
1790  NamedAttrList attrs;
1791 
1792  if (parser.parseAttribute(specConstRef, Type(), attrName, attrs))
1793  return failure();
1794 
1795  constituents.push_back(specConstRef);
1796  } while (!parser.parseOptionalComma());
1797 
1798  if (parser.parseRParen())
1799  return failure();
1800 
1801  StringAttr compositeSpecConstituentsName =
1802  spirv::SpecConstantCompositeOp::getConstituentsAttrName(result.name);
1803  result.addAttribute(compositeSpecConstituentsName,
1804  parser.getBuilder().getArrayAttr(constituents));
1805 
1806  Type type;
1807  if (parser.parseColonType(type))
1808  return failure();
1809 
1810  StringAttr typeAttrName =
1811  spirv::SpecConstantCompositeOp::getTypeAttrName(result.name);
1812  result.addAttribute(typeAttrName, TypeAttr::get(type));
1813 
1814  return success();
1815 }
1816 
1818  printer << " ";
1819  printer.printSymbolName(getSymName());
1820  printer << " (" << llvm::interleaved(this->getConstituents().getValue())
1821  << ") : " << getType();
1822 }
1823 
1825  auto cType = llvm::dyn_cast<spirv::CompositeType>(getType());
1826  auto constituents = this->getConstituents().getValue();
1827 
1828  if (!cType)
1829  return emitError("result type must be a composite type, but provided ")
1830  << getType();
1831 
1832  if (llvm::isa<spirv::CooperativeMatrixType>(cType))
1833  return emitError("unsupported composite type ") << cType;
1834  if (constituents.size() != cType.getNumElements())
1835  return emitError("has incorrect number of operands: expected ")
1836  << cType.getNumElements() << ", but provided "
1837  << constituents.size();
1838 
1839  for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
1840  auto constituent = llvm::cast<FlatSymbolRefAttr>(constituents[index]);
1841 
1842  auto constituentSpecConstOp =
1843  dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom(
1844  (*this)->getParentOp(), constituent.getAttr()));
1845 
1846  if (constituentSpecConstOp.getDefaultValue().getType() !=
1847  cType.getElementType(index))
1848  return emitError("has incorrect types of operands: expected ")
1849  << cType.getElementType(index) << ", but provided "
1850  << constituentSpecConstOp.getDefaultValue().getType();
1851  }
1852 
1853  return success();
1854 }
1855 
1856 //===----------------------------------------------------------------------===//
1857 // spirv.SpecConstantOperation
1858 //===----------------------------------------------------------------------===//
1859 
1861  OperationState &result) {
1862  Region *body = result.addRegion();
1863 
1864  if (parser.parseKeyword("wraps"))
1865  return failure();
1866 
1867  body->push_back(new Block);
1868  Block &block = body->back();
1869  Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
1870 
1871  if (!wrappedOp)
1872  return failure();
1873 
1874  OpBuilder builder(parser.getContext());
1875  builder.setInsertionPointToEnd(&block);
1876  builder.create<spirv::YieldOp>(wrappedOp->getLoc(), wrappedOp->getResult(0));
1877  result.location = wrappedOp->getLoc();
1878 
1879  result.addTypes(wrappedOp->getResult(0).getType());
1880 
1881  if (parser.parseOptionalAttrDict(result.attributes))
1882  return failure();
1883 
1884  return success();
1885 }
1886 
1888  printer << " wraps ";
1889  printer.printGenericOp(&getBody().front().front());
1890 }
1891 
1892 LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
1893  Block &block = getRegion().getBlocks().front();
1894 
1895  if (block.getOperations().size() != 2)
1896  return emitOpError("expected exactly 2 nested ops");
1897 
1898  Operation &enclosedOp = block.getOperations().front();
1899 
1901  return emitOpError("invalid enclosed op");
1902 
1903  for (auto operand : enclosedOp.getOperands())
1904  if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
1905  spirv::SpecConstantOperationOp>(operand.getDefiningOp()))
1906  return emitOpError(
1907  "invalid operand, must be defined by a constant operation");
1908 
1909  return success();
1910 }
1911 
1912 //===----------------------------------------------------------------------===//
1913 // spirv.GL.FrexpStruct
1914 //===----------------------------------------------------------------------===//
1915 
1916 LogicalResult spirv::GLFrexpStructOp::verify() {
1917  spirv::StructType structTy =
1918  llvm::dyn_cast<spirv::StructType>(getResult().getType());
1919 
1920  if (structTy.getNumElements() != 2)
1921  return emitError("result type must be a struct type with two memebers");
1922 
1923  Type significandTy = structTy.getElementType(0);
1924  Type exponentTy = structTy.getElementType(1);
1925  VectorType exponentVecTy = llvm::dyn_cast<VectorType>(exponentTy);
1926  IntegerType exponentIntTy = llvm::dyn_cast<IntegerType>(exponentTy);
1927 
1928  Type operandTy = getOperand().getType();
1929  VectorType operandVecTy = llvm::dyn_cast<VectorType>(operandTy);
1930  FloatType operandFTy = llvm::dyn_cast<FloatType>(operandTy);
1931 
1932  if (significandTy != operandTy)
1933  return emitError("member zero of the resulting struct type must be the "
1934  "same type as the operand");
1935 
1936  if (exponentVecTy) {
1937  IntegerType componentIntTy =
1938  llvm::dyn_cast<IntegerType>(exponentVecTy.getElementType());
1939  if (!componentIntTy || componentIntTy.getWidth() != 32)
1940  return emitError("member one of the resulting struct type must"
1941  "be a scalar or vector of 32 bit integer type");
1942  } else if (!exponentIntTy || exponentIntTy.getWidth() != 32) {
1943  return emitError("member one of the resulting struct type "
1944  "must be a scalar or vector of 32 bit integer type");
1945  }
1946 
1947  // Check that the two member types have the same number of components
1948  if (operandVecTy && exponentVecTy &&
1949  (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
1950  return success();
1951 
1952  if (operandFTy && exponentIntTy)
1953  return success();
1954 
1955  return emitError("member one of the resulting struct type must have the same "
1956  "number of components as the operand type");
1957 }
1958 
1959 //===----------------------------------------------------------------------===//
1960 // spirv.GL.Ldexp
1961 //===----------------------------------------------------------------------===//
1962 
1963 LogicalResult spirv::GLLdexpOp::verify() {
1964  Type significandType = getX().getType();
1965  Type exponentType = getExp().getType();
1966 
1967  if (llvm::isa<FloatType>(significandType) !=
1968  llvm::isa<IntegerType>(exponentType))
1969  return emitOpError("operands must both be scalars or vectors");
1970 
1971  auto getNumElements = [](Type type) -> unsigned {
1972  if (auto vectorType = llvm::dyn_cast<VectorType>(type))
1973  return vectorType.getNumElements();
1974  return 1;
1975  };
1976 
1977  if (getNumElements(significandType) != getNumElements(exponentType))
1978  return emitOpError("operands must have the same number of elements");
1979 
1980  return success();
1981 }
1982 
1983 //===----------------------------------------------------------------------===//
1984 // spirv.ShiftLeftLogicalOp
1985 //===----------------------------------------------------------------------===//
1986 
1987 LogicalResult spirv::ShiftLeftLogicalOp::verify() {
1988  return verifyShiftOp(*this);
1989 }
1990 
1991 //===----------------------------------------------------------------------===//
1992 // spirv.ShiftRightArithmeticOp
1993 //===----------------------------------------------------------------------===//
1994 
1995 LogicalResult spirv::ShiftRightArithmeticOp::verify() {
1996  return verifyShiftOp(*this);
1997 }
1998 
1999 //===----------------------------------------------------------------------===//
2000 // spirv.ShiftRightLogicalOp
2001 //===----------------------------------------------------------------------===//
2002 
2003 LogicalResult spirv::ShiftRightLogicalOp::verify() {
2004  return verifyShiftOp(*this);
2005 }
2006 
2007 //===----------------------------------------------------------------------===//
2008 // spirv.VectorTimesScalarOp
2009 //===----------------------------------------------------------------------===//
2010 
2011 LogicalResult spirv::VectorTimesScalarOp::verify() {
2012  if (getVector().getType() != getType())
2013  return emitOpError("vector operand and result type mismatch");
2014  auto scalarType = llvm::cast<VectorType>(getType()).getElementType();
2015  if (getScalar().getType() != scalarType)
2016  return emitOpError("scalar operand and result element type match");
2017  return success();
2018 }
static std::string bindingName()
Returns the string name of the Binding decoration.
static std::string descriptorSetName()
Returns the string name of the DescriptorSet decoration.
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser, OperationState &result)
Definition: SPIRVOps.cpp:269
static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, Type opType)
Definition: SPIRVOps.cpp:559
static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, OperationState &result)
Definition: SPIRVOps.cpp:122
static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op)
Definition: SPIRVOps.cpp:255
static LogicalResult verifyShiftOp(Operation *op)
Definition: SPIRVOps.cpp:301
static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op, Value ptr, Value val)
Definition: SPIRVOps.cpp:171
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:188
static void printOneResultOp(Operation *op, OpAsmPrinter &p)
Definition: SPIRVOps.cpp:151
static void printArithmeticExtendedBinaryOp(Operation *op, OpAsmPrinter &printer)
Definition: SPIRVOps.cpp:293
const float * table
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseOptionalSymbolName(StringAttr &result)=0
Parse an optional @-identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
ParseResult addTypesToList(ArrayRef< Type > types, SmallVectorImpl< Type > &result)
Add the specified types to the end of the specified type list and return success.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:151
unsigned getNumArguments()
Definition: Block.h:128
OpListType & getOperations()
Definition: Block.h:137
Operation & front()
Definition: Block.h:153
iterator begin()
Definition: Block.h:143
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:196
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:224
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:272
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:250
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:76
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:67
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:96
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:258
MLIRContext * getContext() const
Definition: Builders.h:55
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:262
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition: Builders.h:95
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual Operation * parseGenericOperation(Block *insertBlock, Block::iterator insertPt)=0
Parse an operation in its generic form.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printGenericOp(Operation *op, bool printOpName=true)=0
Print the entire operation with the default generic assembly form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:345
This class helps build Operations.
Definition: Builders.h:204
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:433
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:426
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
A trait to mark ops that can be enclosed/wrapped in a SpecConstantOperation op.
Definition: SPIRVOpTraits.h:33
type_range getType() const
Definition: ValueRange.cpp:32
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:550
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
void push_back(Block *block)
Definition: Region.h:61
bool empty()
Definition: Region.h:60
Block & back()
Definition: Region.h:64
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:107
Type front()
Return first type in the range.
Definition: TypeRange.h:152
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:33
static WalkResult advance()
Definition: Visitors.h:51
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:406
SPIR-V struct type.
Definition: SPIRVTypes.h:293
unsigned getNumElements() const
Type getElementType(unsigned) const
An attribute that specifies the SPIR-V (version, capabilities, extensions) triple.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:136
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the argument at 'index'.
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
constexpr char kFnNameAttrName[]
constexpr char kSpecIdAttrName[]
LogicalResult verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics)
Definition: SPIRVOps.cpp:71
ParseResult parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next string attribute in parser as an enumerant of the given EnumClass.
void printVariableDecorations(Operation *op, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs)
Definition: SPIRVOps.cpp:95
AddressingModel getAddressingModel(TargetEnvAttr targetAttr, bool use64bitAddress)
Returns addressing model selected based on target environment.
FailureOr< ExecutionModel > getExecutionModel(TargetEnvAttr targetAttr)
Returns execution model selected based on target environment.
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
LogicalResult extractValueFromConstOp(Operation *op, int32_t &value)
Definition: SPIRVOps.cpp:51
ParseResult parseVariableDecorations(OpAsmParser &parser, OperationState &state)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.