MLIR  21.0.0git
SPIRVOps.cpp
Go to the documentation of this file.
1 //===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the operations in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
15 #include "SPIRVOpUtils.h"
16 #include "SPIRVParsingUtils.h"
17 
24 #include "mlir/IR/Builders.h"
25 #include "mlir/IR/BuiltinTypes.h"
26 #include "mlir/IR/Matchers.h"
27 #include "mlir/IR/OpDefinition.h"
29 #include "mlir/IR/Operation.h"
30 #include "mlir/IR/TypeUtilities.h"
32 #include "llvm/ADT/APFloat.h"
33 #include "llvm/ADT/APInt.h"
34 #include "llvm/ADT/ArrayRef.h"
35 #include "llvm/ADT/STLExtras.h"
36 #include "llvm/ADT/StringExtras.h"
37 #include "llvm/ADT/TypeSwitch.h"
38 #include "llvm/Support/InterleavedRange.h"
39 #include <cassert>
40 #include <numeric>
41 #include <optional>
42 #include <type_traits>
43 
44 using namespace mlir;
45 using namespace mlir::spirv::AttrNames;
46 
47 //===----------------------------------------------------------------------===//
48 // Common utility functions
49 //===----------------------------------------------------------------------===//
50 
51 LogicalResult spirv::extractValueFromConstOp(Operation *op, int32_t &value) {
52  auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
53  if (!constOp) {
54  return failure();
55  }
56  auto valueAttr = constOp.getValue();
57  auto integerValueAttr = llvm::dyn_cast<IntegerAttr>(valueAttr);
58  if (!integerValueAttr) {
59  return failure();
60  }
61 
62  if (integerValueAttr.getType().isSignlessInteger())
63  value = integerValueAttr.getInt();
64  else
65  value = integerValueAttr.getSInt();
66 
67  return success();
68 }
69 
70 LogicalResult
72  spirv::MemorySemantics memorySemantics) {
73  // According to the SPIR-V specification:
74  // "Despite being a mask and allowing multiple bits to be combined, it is
75  // invalid for more than one of these four bits to be set: Acquire, Release,
76  // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and
77  // Release semantics is done by setting the AcquireRelease bit, not by setting
78  // two bits."
79  auto atMostOneInSet = spirv::MemorySemantics::Acquire |
80  spirv::MemorySemantics::Release |
81  spirv::MemorySemantics::AcquireRelease |
82  spirv::MemorySemantics::SequentiallyConsistent;
83 
84  auto bitCount =
85  llvm::popcount(static_cast<uint32_t>(memorySemantics & atMostOneInSet));
86  if (bitCount > 1) {
87  return op->emitError(
88  "expected at most one of these four memory constraints "
89  "to be set: `Acquire`, `Release`,"
90  "`AcquireRelease` or `SequentiallyConsistent`");
91  }
92  return success();
93 }
94 
96  SmallVectorImpl<StringRef> &elidedAttrs) {
97  // Print optional descriptor binding
98  auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
99  stringifyDecoration(spirv::Decoration::DescriptorSet));
100  auto bindingName = llvm::convertToSnakeFromCamelCase(
101  stringifyDecoration(spirv::Decoration::Binding));
102  auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
103  auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
104  if (descriptorSet && binding) {
105  elidedAttrs.push_back(descriptorSetName);
106  elidedAttrs.push_back(bindingName);
107  printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
108  << ")";
109  }
110 
111  // Print BuiltIn attribute if present
112  auto builtInName = llvm::convertToSnakeFromCamelCase(
113  stringifyDecoration(spirv::Decoration::BuiltIn));
114  if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
115  printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
116  elidedAttrs.push_back(builtInName);
117  }
118 
119  printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
120 }
121 
122 static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
123  OperationState &result) {
125  Type type;
126  // If the operand list is in-between parentheses, then we have a generic form.
127  // (see the fallback in `printOneResultOp`).
128  SMLoc loc = parser.getCurrentLocation();
129  if (!parser.parseOptionalLParen()) {
130  if (parser.parseOperandList(ops) || parser.parseRParen() ||
131  parser.parseOptionalAttrDict(result.attributes) ||
132  parser.parseColon() || parser.parseType(type))
133  return failure();
134  auto fnType = llvm::dyn_cast<FunctionType>(type);
135  if (!fnType) {
136  parser.emitError(loc, "expected function type");
137  return failure();
138  }
139  if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands))
140  return failure();
141  result.addTypes(fnType.getResults());
142  return success();
143  }
144  return failure(parser.parseOperandList(ops) ||
145  parser.parseOptionalAttrDict(result.attributes) ||
146  parser.parseColonType(type) ||
147  parser.resolveOperands(ops, type, result.operands) ||
148  parser.addTypeToList(type, result.types));
149 }
150 
152  assert(op->getNumResults() == 1 && "op should have one result");
153 
154  // If not all the operand and result types are the same, just use the
155  // generic assembly form to avoid omitting information in printing.
156  auto resultType = op->getResult(0).getType();
157  if (llvm::any_of(op->getOperandTypes(),
158  [&](Type type) { return type != resultType; })) {
159  p.printGenericOp(op, /*printOpName=*/false);
160  return;
161  }
162 
163  p << ' ';
164  p.printOperands(op->getOperands());
166  // Now we can output only one type for all operands and the result.
167  p << " : " << resultType;
168 }
169 
170 template <typename BlockReadWriteOpTy>
171 static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
172  Value ptr, Value val) {
173  auto valType = val.getType();
174  if (auto valVecTy = llvm::dyn_cast<VectorType>(valType))
175  valType = valVecTy.getElementType();
176 
177  if (valType !=
178  llvm::cast<spirv::PointerType>(ptr.getType()).getPointeeType()) {
179  return op.emitOpError("mismatch in result type and pointer type");
180  }
181  return success();
182 }
183 
184 /// Walks the given type hierarchy with the given indices, potentially down
185 /// to component granularity, to select an element type. Returns null type and
186 /// emits errors with the given loc on failure.
187 static Type
189  function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
190  if (indices.empty()) {
191  emitErrorFn("expected at least one index for spirv.CompositeExtract");
192  return nullptr;
193  }
194 
195  for (auto index : indices) {
196  if (auto cType = llvm::dyn_cast<spirv::CompositeType>(type)) {
197  if (cType.hasCompileTimeKnownNumElements() &&
198  (index < 0 ||
199  static_cast<uint64_t>(index) >= cType.getNumElements())) {
200  emitErrorFn("index ") << index << " out of bounds for " << type;
201  return nullptr;
202  }
203  type = cType.getElementType(index);
204  } else {
205  emitErrorFn("cannot extract from non-composite type ")
206  << type << " with index " << index;
207  return nullptr;
208  }
209  }
210  return type;
211 }
212 
213 static Type
215  function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
216  auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(indices);
217  if (!indicesArrayAttr) {
218  emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
219  return nullptr;
220  }
221  if (indicesArrayAttr.empty()) {
222  emitErrorFn("expected at least one index for spirv.CompositeExtract");
223  return nullptr;
224  }
225 
226  SmallVector<int32_t, 2> indexVals;
227  for (auto indexAttr : indicesArrayAttr) {
228  auto indexIntAttr = llvm::dyn_cast<IntegerAttr>(indexAttr);
229  if (!indexIntAttr) {
230  emitErrorFn("expected an 32-bit integer for index, but found '")
231  << indexAttr << "'";
232  return nullptr;
233  }
234  indexVals.push_back(indexIntAttr.getInt());
235  }
236  return getElementType(type, indexVals, emitErrorFn);
237 }
238 
239 static Type getElementType(Type type, Attribute indices, Location loc) {
240  auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
241  return ::mlir::emitError(loc, err);
242  };
243  return getElementType(type, indices, errorFn);
244 }
245 
246 static Type getElementType(Type type, Attribute indices, OpAsmParser &parser,
247  SMLoc loc) {
248  auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
249  return parser.emitError(loc, err);
250  };
251  return getElementType(type, indices, errorFn);
252 }
253 
254 template <typename ExtendedBinaryOp>
255 static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op) {
256  auto resultType = llvm::cast<spirv::StructType>(op.getType());
257  if (resultType.getNumElements() != 2)
258  return op.emitOpError("expected result struct type containing two members");
259 
260  if (!llvm::all_equal({op.getOperand1().getType(), op.getOperand2().getType(),
261  resultType.getElementType(0),
262  resultType.getElementType(1)}))
263  return op.emitOpError(
264  "expected all operand types and struct member types are the same");
265 
266  return success();
267 }
268 
269 static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser,
270  OperationState &result) {
272  if (parser.parseOptionalAttrDict(result.attributes) ||
273  parser.parseOperandList(operands) || parser.parseColon())
274  return failure();
275 
276  Type resultType;
277  SMLoc loc = parser.getCurrentLocation();
278  if (parser.parseType(resultType))
279  return failure();
280 
281  auto structType = llvm::dyn_cast<spirv::StructType>(resultType);
282  if (!structType || structType.getNumElements() != 2)
283  return parser.emitError(loc, "expected spirv.struct type with two members");
284 
285  SmallVector<Type, 2> operandTypes(2, structType.getElementType(0));
286  if (parser.resolveOperands(operands, operandTypes, loc, result.operands))
287  return failure();
288 
289  result.addTypes(resultType);
290  return success();
291 }
292 
294  OpAsmPrinter &printer) {
295  printer << ' ';
296  printer.printOptionalAttrDict(op->getAttrs());
297  printer.printOperands(op->getOperands());
298  printer << " : " << op->getResultTypes().front();
299 }
300 
301 static LogicalResult verifyShiftOp(Operation *op) {
302  if (op->getOperand(0).getType() != op->getResult(0).getType()) {
303  return op->emitError("expected the same type for the first operand and "
304  "result, but provided ")
305  << op->getOperand(0).getType() << " and "
306  << op->getResult(0).getType();
307  }
308  return success();
309 }
310 
311 //===----------------------------------------------------------------------===//
312 // spirv.mlir.addressof
313 //===----------------------------------------------------------------------===//
314 
315 void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state,
316  spirv::GlobalVariableOp var) {
317  build(builder, state, var.getType(), SymbolRefAttr::get(var));
318 }
319 
320 LogicalResult spirv::AddressOfOp::verify() {
321  auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
322  SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(),
323  getVariableAttr()));
324  if (!varOp) {
325  return emitOpError("expected spirv.GlobalVariable symbol");
326  }
327  if (getPointer().getType() != varOp.getType()) {
328  return emitOpError(
329  "result type mismatch with the referenced global variable's type");
330  }
331  return success();
332 }
333 
334 //===----------------------------------------------------------------------===//
335 // spirv.CompositeConstruct
336 //===----------------------------------------------------------------------===//
337 
338 LogicalResult spirv::CompositeConstructOp::verify() {
339  operand_range constituents = this->getConstituents();
340 
341  // There are 4 cases with varying verification rules:
342  // 1. Cooperative Matrices (1 constituent)
343  // 2. Structs (1 constituent for each member)
344  // 3. Arrays (1 constituent for each array element)
345  // 4. Vectors (1 constituent (sub-)element for each vector element)
346 
347  auto coopElementType =
350  [](auto coopType) { return coopType.getElementType(); })
351  .Default([](Type) { return nullptr; });
352 
353  // Case 1. -- matrices.
354  if (coopElementType) {
355  if (constituents.size() != 1)
356  return emitOpError("has incorrect number of operands: expected ")
357  << "1, but provided " << constituents.size();
358  if (coopElementType != constituents.front().getType())
359  return emitOpError("operand type mismatch: expected operand type ")
360  << coopElementType << ", but provided "
361  << constituents.front().getType();
362  return success();
363  }
364 
365  // Case 2./3./4. -- number of constituents matches the number of elements.
366  auto cType = llvm::cast<spirv::CompositeType>(getType());
367  if (constituents.size() == cType.getNumElements()) {
368  for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
369  if (constituents[index].getType() != cType.getElementType(index)) {
370  return emitOpError("operand type mismatch: expected operand type ")
371  << cType.getElementType(index) << ", but provided "
372  << constituents[index].getType();
373  }
374  }
375  return success();
376  }
377 
378  // Case 4. -- check that all constituents add up tp the expected vector type.
379  auto resultType = llvm::dyn_cast<VectorType>(cType);
380  if (!resultType)
381  return emitOpError(
382  "expected to return a vector or cooperative matrix when the number of "
383  "constituents is less than what the result needs");
384 
385  SmallVector<unsigned> sizes;
386  for (Value component : constituents) {
387  if (!llvm::isa<VectorType>(component.getType()) &&
388  !component.getType().isIntOrFloat())
389  return emitOpError("operand type mismatch: expected operand to have "
390  "a scalar or vector type, but provided ")
391  << component.getType();
392 
393  Type elementType = component.getType();
394  if (auto vectorType = llvm::dyn_cast<VectorType>(component.getType())) {
395  sizes.push_back(vectorType.getNumElements());
396  elementType = vectorType.getElementType();
397  } else {
398  sizes.push_back(1);
399  }
400 
401  if (elementType != resultType.getElementType())
402  return emitOpError("operand element type mismatch: expected to be ")
403  << resultType.getElementType() << ", but provided " << elementType;
404  }
405  unsigned totalCount = std::accumulate(sizes.begin(), sizes.end(), 0);
406  if (totalCount != cType.getNumElements())
407  return emitOpError("has incorrect number of operands: expected ")
408  << cType.getNumElements() << ", but provided " << totalCount;
409  return success();
410 }
411 
412 //===----------------------------------------------------------------------===//
413 // spirv.CompositeExtractOp
414 //===----------------------------------------------------------------------===//
415 
416 void spirv::CompositeExtractOp::build(OpBuilder &builder, OperationState &state,
417  Value composite,
418  ArrayRef<int32_t> indices) {
419  auto indexAttr = builder.getI32ArrayAttr(indices);
420  auto elementType =
421  getElementType(composite.getType(), indexAttr, state.location);
422  if (!elementType) {
423  return;
424  }
425  build(builder, state, elementType, composite, indexAttr);
426 }
427 
429  OperationState &result) {
430  OpAsmParser::UnresolvedOperand compositeInfo;
431  Attribute indicesAttr;
432  StringRef indicesAttrName =
433  spirv::CompositeExtractOp::getIndicesAttrName(result.name);
434  Type compositeType;
435  SMLoc attrLocation;
436 
437  if (parser.parseOperand(compositeInfo) ||
438  parser.getCurrentLocation(&attrLocation) ||
439  parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
440  parser.parseColonType(compositeType) ||
441  parser.resolveOperand(compositeInfo, compositeType, result.operands)) {
442  return failure();
443  }
444 
445  Type resultType =
446  getElementType(compositeType, indicesAttr, parser, attrLocation);
447  if (!resultType) {
448  return failure();
449  }
450  result.addTypes(resultType);
451  return success();
452 }
453 
455  printer << ' ' << getComposite() << getIndices() << " : "
456  << getComposite().getType();
457 }
458 
459 LogicalResult spirv::CompositeExtractOp::verify() {
460  auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(getIndices());
461  auto resultType =
462  getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
463  if (!resultType)
464  return failure();
465 
466  if (resultType != getType()) {
467  return emitOpError("invalid result type: expected ")
468  << resultType << " but provided " << getType();
469  }
470 
471  return success();
472 }
473 
474 //===----------------------------------------------------------------------===//
475 // spirv.CompositeInsert
476 //===----------------------------------------------------------------------===//
477 
478 void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state,
479  Value object, Value composite,
480  ArrayRef<int32_t> indices) {
481  auto indexAttr = builder.getI32ArrayAttr(indices);
482  build(builder, state, composite.getType(), object, composite, indexAttr);
483 }
484 
485 ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
486  OperationState &result) {
488  Type objectType, compositeType;
489  Attribute indicesAttr;
490  StringRef indicesAttrName =
491  spirv::CompositeInsertOp::getIndicesAttrName(result.name);
492  auto loc = parser.getCurrentLocation();
493 
494  return failure(
495  parser.parseOperandList(operands, 2) ||
496  parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
497  parser.parseColonType(objectType) ||
498  parser.parseKeywordType("into", compositeType) ||
499  parser.resolveOperands(operands, {objectType, compositeType}, loc,
500  result.operands) ||
501  parser.addTypesToList(compositeType, result.types));
502 }
503 
504 LogicalResult spirv::CompositeInsertOp::verify() {
505  auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(getIndices());
506  auto objectType =
507  getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
508  if (!objectType)
509  return failure();
510 
511  if (objectType != getObject().getType()) {
512  return emitOpError("object operand type should be ")
513  << objectType << ", but found " << getObject().getType();
514  }
515 
516  if (getComposite().getType() != getType()) {
517  return emitOpError("result type should be the same as "
518  "the composite type, but found ")
519  << getComposite().getType() << " vs " << getType();
520  }
521 
522  return success();
523 }
524 
526  printer << " " << getObject() << ", " << getComposite() << getIndices()
527  << " : " << getObject().getType() << " into "
528  << getComposite().getType();
529 }
530 
531 //===----------------------------------------------------------------------===//
532 // spirv.Constant
533 //===----------------------------------------------------------------------===//
534 
535 ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
536  OperationState &result) {
537  Attribute value;
538  StringRef valueAttrName = spirv::ConstantOp::getValueAttrName(result.name);
539  if (parser.parseAttribute(value, valueAttrName, result.attributes))
540  return failure();
541 
542  Type type = NoneType::get(parser.getContext());
543  if (auto typedAttr = llvm::dyn_cast<TypedAttr>(value))
544  type = typedAttr.getType();
545  if (llvm::isa<NoneType, TensorType>(type)) {
546  if (parser.parseColonType(type))
547  return failure();
548  }
549 
550  if (llvm::isa<TensorArmType>(type)) {
551  if (parser.parseOptionalColon().succeeded())
552  if (parser.parseType(type))
553  return failure();
554  }
555 
556  return parser.addTypeToList(type, result.types);
557 }
558 
559 void spirv::ConstantOp::print(OpAsmPrinter &printer) {
560  printer << ' ' << getValue();
561  if (llvm::isa<spirv::ArrayType>(getType()))
562  printer << " : " << getType();
563 }
564 
565 static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
566  Type opType) {
567  if (isa<spirv::CooperativeMatrixType>(opType)) {
568  auto denseAttr = dyn_cast<DenseElementsAttr>(value);
569  if (!denseAttr || !denseAttr.isSplat())
570  return op.emitOpError("expected a splat dense attribute for cooperative "
571  "matrix constant, but found ")
572  << denseAttr;
573  }
574  if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
575  auto valueType = llvm::cast<TypedAttr>(value).getType();
576  if (valueType != opType)
577  return op.emitOpError("result type (")
578  << opType << ") does not match value type (" << valueType << ")";
579  return success();
580  }
581  if (llvm::isa<DenseIntOrFPElementsAttr, SparseElementsAttr>(value)) {
582  auto valueType = llvm::cast<TypedAttr>(value).getType();
583  if (valueType == opType)
584  return success();
585  auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
586  auto shapedType = llvm::dyn_cast<ShapedType>(valueType);
587  if (!arrayType)
588  return op.emitOpError("result or element type (")
589  << opType << ") does not match value type (" << valueType
590  << "), must be the same or spirv.array";
591 
592  int numElements = arrayType.getNumElements();
593  auto opElemType = arrayType.getElementType();
594  while (auto t = llvm::dyn_cast<spirv::ArrayType>(opElemType)) {
595  numElements *= t.getNumElements();
596  opElemType = t.getElementType();
597  }
598  if (!opElemType.isIntOrFloat())
599  return op.emitOpError("only support nested array result type");
600 
601  auto valueElemType = shapedType.getElementType();
602  if (valueElemType != opElemType) {
603  return op.emitOpError("result element type (")
604  << opElemType << ") does not match value element type ("
605  << valueElemType << ")";
606  }
607 
608  if (numElements != shapedType.getNumElements()) {
609  return op.emitOpError("result number of elements (")
610  << numElements << ") does not match value number of elements ("
611  << shapedType.getNumElements() << ")";
612  }
613  return success();
614  }
615  if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(value)) {
616  auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
617  if (!arrayType)
618  return op.emitOpError(
619  "must have spirv.array result type for array value");
620  Type elemType = arrayType.getElementType();
621  for (Attribute element : arrayAttr.getValue()) {
622  // Verify array elements recursively.
623  if (failed(verifyConstantType(op, element, elemType)))
624  return failure();
625  }
626  return success();
627  }
628  return op.emitOpError("cannot have attribute: ") << value;
629 }
630 
631 LogicalResult spirv::ConstantOp::verify() {
632  // ODS already generates checks to make sure the result type is valid. We just
633  // need to additionally check that the value's attribute type is consistent
634  // with the result type.
635  return verifyConstantType(*this, getValueAttr(), getType());
636 }
637 
638 bool spirv::ConstantOp::isBuildableWith(Type type) {
639  // Must be valid SPIR-V type first.
640  if (!llvm::isa<spirv::SPIRVType>(type))
641  return false;
642 
643  if (isa<SPIRVDialect>(type.getDialect())) {
644  // TODO: support constant struct
645  return llvm::isa<spirv::ArrayType>(type);
646  }
647 
648  return true;
649 }
650 
651 spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
652  OpBuilder &builder) {
653  if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
654  unsigned width = intType.getWidth();
655  if (width == 1)
656  return builder.create<spirv::ConstantOp>(loc, type,
657  builder.getBoolAttr(false));
658  return builder.create<spirv::ConstantOp>(
659  loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
660  }
661  if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
662  return builder.create<spirv::ConstantOp>(
663  loc, type, builder.getFloatAttr(floatType, 0.0));
664  }
665  if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
666  Type elemType = vectorType.getElementType();
667  if (llvm::isa<IntegerType>(elemType)) {
668  return builder.create<spirv::ConstantOp>(
669  loc, type,
670  DenseElementsAttr::get(vectorType,
671  IntegerAttr::get(elemType, 0).getValue()));
672  }
673  if (llvm::isa<FloatType>(elemType)) {
674  return builder.create<spirv::ConstantOp>(
675  loc, type,
676  DenseFPElementsAttr::get(vectorType,
677  FloatAttr::get(elemType, 0.0).getValue()));
678  }
679  }
680 
681  llvm_unreachable("unimplemented types for ConstantOp::getZero()");
682 }
683 
684 spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
685  OpBuilder &builder) {
686  if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
687  unsigned width = intType.getWidth();
688  if (width == 1)
689  return builder.create<spirv::ConstantOp>(loc, type,
690  builder.getBoolAttr(true));
691  return builder.create<spirv::ConstantOp>(
692  loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
693  }
694  if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
695  return builder.create<spirv::ConstantOp>(
696  loc, type, builder.getFloatAttr(floatType, 1.0));
697  }
698  if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
699  Type elemType = vectorType.getElementType();
700  if (llvm::isa<IntegerType>(elemType)) {
701  return builder.create<spirv::ConstantOp>(
702  loc, type,
703  DenseElementsAttr::get(vectorType,
704  IntegerAttr::get(elemType, 1).getValue()));
705  }
706  if (llvm::isa<FloatType>(elemType)) {
707  return builder.create<spirv::ConstantOp>(
708  loc, type,
709  DenseFPElementsAttr::get(vectorType,
710  FloatAttr::get(elemType, 1.0).getValue()));
711  }
712  }
713 
714  llvm_unreachable("unimplemented types for ConstantOp::getOne()");
715 }
716 
717 void mlir::spirv::ConstantOp::getAsmResultNames(
718  llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
719  Type type = getType();
720 
721  SmallString<32> specialNameBuffer;
722  llvm::raw_svector_ostream specialName(specialNameBuffer);
723  specialName << "cst";
724 
725  IntegerType intTy = llvm::dyn_cast<IntegerType>(type);
726 
727  if (IntegerAttr intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
728  if (intTy && intTy.getWidth() == 1) {
729  return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
730  }
731 
732  if (intTy.isSignless()) {
733  specialName << intCst.getInt();
734  } else if (intTy.isUnsigned()) {
735  specialName << intCst.getUInt();
736  } else {
737  specialName << intCst.getSInt();
738  }
739  }
740 
741  if (intTy || llvm::isa<FloatType>(type)) {
742  specialName << '_' << type;
743  }
744 
745  if (auto vecType = llvm::dyn_cast<VectorType>(type)) {
746  specialName << "_vec_";
747  specialName << vecType.getDimSize(0);
748 
749  Type elementType = vecType.getElementType();
750 
751  if (llvm::isa<IntegerType>(elementType) ||
752  llvm::isa<FloatType>(elementType)) {
753  specialName << "x" << elementType;
754  }
755  }
756 
757  setNameFn(getResult(), specialName.str());
758 }
759 
760 void mlir::spirv::AddressOfOp::getAsmResultNames(
761  llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
762  SmallString<32> specialNameBuffer;
763  llvm::raw_svector_ostream specialName(specialNameBuffer);
764  specialName << getVariable() << "_addr";
765  setNameFn(getResult(), specialName.str());
766 }
767 
768 //===----------------------------------------------------------------------===//
769 // spirv.ControlBarrierOp
770 //===----------------------------------------------------------------------===//
771 
772 LogicalResult spirv::ControlBarrierOp::verify() {
773  return verifyMemorySemantics(getOperation(), getMemorySemantics());
774 }
775 
776 //===----------------------------------------------------------------------===//
777 // spirv.EntryPoint
778 //===----------------------------------------------------------------------===//
779 
780 void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state,
781  spirv::ExecutionModel executionModel,
782  spirv::FuncOp function,
783  ArrayRef<Attribute> interfaceVars) {
784  build(builder, state,
785  spirv::ExecutionModelAttr::get(builder.getContext(), executionModel),
786  SymbolRefAttr::get(function), builder.getArrayAttr(interfaceVars));
787 }
788 
789 ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
790  OperationState &result) {
791  spirv::ExecutionModel execModel;
792  SmallVector<Attribute, 4> interfaceVars;
793 
795  if (parseEnumStrAttr<spirv::ExecutionModelAttr>(execModel, parser, result) ||
796  parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) {
797  return failure();
798  }
799 
800  if (!parser.parseOptionalComma()) {
801  // Parse the interface variables
802  if (parser.parseCommaSeparatedList([&]() -> ParseResult {
803  // The name of the interface variable attribute isnt important
804  FlatSymbolRefAttr var;
805  NamedAttrList attrs;
806  if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
807  return failure();
808  interfaceVars.push_back(var);
809  return success();
810  }))
811  return failure();
812  }
813  result.addAttribute(spirv::EntryPointOp::getInterfaceAttrName(result.name),
814  parser.getBuilder().getArrayAttr(interfaceVars));
815  return success();
816 }
817 
819  printer << " \"" << stringifyExecutionModel(getExecutionModel()) << "\" ";
820  printer.printSymbolName(getFn());
821  auto interfaceVars = getInterface().getValue();
822  if (!interfaceVars.empty())
823  printer << ", " << llvm::interleaved(interfaceVars);
824 }
825 
826 LogicalResult spirv::EntryPointOp::verify() {
827  // Checks for fn and interface symbol reference are done in spirv::ModuleOp
828  // verification.
829  return success();
830 }
831 
832 //===----------------------------------------------------------------------===//
833 // spirv.ExecutionMode
834 //===----------------------------------------------------------------------===//
835 
836 void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state,
837  spirv::FuncOp function,
838  spirv::ExecutionMode executionMode,
839  ArrayRef<int32_t> params) {
840  build(builder, state, SymbolRefAttr::get(function),
841  spirv::ExecutionModeAttr::get(builder.getContext(), executionMode),
842  builder.getI32ArrayAttr(params));
843 }
844 
845 ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
846  OperationState &result) {
847  spirv::ExecutionMode execMode;
848  Attribute fn;
849  if (parser.parseAttribute(fn, kFnNameAttrName, result.attributes) ||
850  parseEnumStrAttr<spirv::ExecutionModeAttr>(execMode, parser, result)) {
851  return failure();
852  }
853 
855  Type i32Type = parser.getBuilder().getIntegerType(32);
856  while (!parser.parseOptionalComma()) {
857  NamedAttrList attr;
858  Attribute value;
859  if (parser.parseAttribute(value, i32Type, "value", attr)) {
860  return failure();
861  }
862  values.push_back(llvm::cast<IntegerAttr>(value).getInt());
863  }
864  StringRef valuesAttrName =
865  spirv::ExecutionModeOp::getValuesAttrName(result.name);
866  result.addAttribute(valuesAttrName,
867  parser.getBuilder().getI32ArrayAttr(values));
868  return success();
869 }
870 
872  printer << " ";
873  printer.printSymbolName(getFn());
874  printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\"";
875  ArrayAttr values = this->getValues();
876  if (!values.empty())
877  printer << ", " << llvm::interleaved(values.getAsValueRange<IntegerAttr>());
878 }
879 
880 //===----------------------------------------------------------------------===//
881 // spirv.func
882 //===----------------------------------------------------------------------===//
883 
884 ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
886  SmallVector<DictionaryAttr> resultAttrs;
887  SmallVector<Type> resultTypes;
888  auto &builder = parser.getBuilder();
889 
890  // Parse the name as a symbol.
891  StringAttr nameAttr;
892  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
893  result.attributes))
894  return failure();
895 
896  // Parse the function signature.
897  bool isVariadic = false;
899  parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
900  resultAttrs))
901  return failure();
902 
903  SmallVector<Type> argTypes;
904  for (auto &arg : entryArgs)
905  argTypes.push_back(arg.type);
906  auto fnType = builder.getFunctionType(argTypes, resultTypes);
907  result.addAttribute(getFunctionTypeAttrName(result.name),
908  TypeAttr::get(fnType));
909 
910  // Parse the optional function control keyword.
911  spirv::FunctionControl fnControl;
912  if (parseEnumStrAttr<spirv::FunctionControlAttr>(fnControl, parser, result))
913  return failure();
914 
915  // If additional attributes are present, parse them.
916  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
917  return failure();
918 
919  // Add the attributes to the function arguments.
920  assert(resultAttrs.size() == resultTypes.size());
922  builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
923  getResAttrsAttrName(result.name));
924 
925  // Parse the optional function body.
926  auto *body = result.addRegion();
927  OptionalParseResult parseResult =
928  parser.parseOptionalRegion(*body, entryArgs);
929  return failure(parseResult.has_value() && failed(*parseResult));
930 }
931 
932 void spirv::FuncOp::print(OpAsmPrinter &printer) {
933  // Print function name, signature, and control.
934  printer << " ";
935  printer.printSymbolName(getSymName());
936  auto fnType = getFunctionType();
938  printer, *this, fnType.getInputs(),
939  /*isVariadic=*/false, fnType.getResults());
940  printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl())
941  << "\"";
943  printer, *this,
944  {spirv::attributeName<spirv::FunctionControl>(),
945  getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
946  getFunctionControlAttrName()});
947 
948  // Print the body if this is not an external function.
949  Region &body = this->getBody();
950  if (!body.empty()) {
951  printer << ' ';
952  printer.printRegion(body, /*printEntryBlockArgs=*/false,
953  /*printBlockTerminators=*/true);
954  }
955 }
956 
957 LogicalResult spirv::FuncOp::verifyType() {
958  FunctionType fnType = getFunctionType();
959  if (fnType.getNumResults() > 1)
960  return emitOpError("cannot have more than one result");
961 
962  auto hasDecorationAttr = [&](spirv::Decoration decoration,
963  unsigned argIndex) {
964  auto func = llvm::cast<FunctionOpInterface>(getOperation());
965  for (auto argAttr : cast<FunctionOpInterface>(func).getArgAttrs(argIndex)) {
966  if (argAttr.getName() != spirv::DecorationAttr::name)
967  continue;
968  if (auto decAttr = dyn_cast<spirv::DecorationAttr>(argAttr.getValue()))
969  return decAttr.getValue() == decoration;
970  }
971  return false;
972  };
973 
974  for (unsigned i = 0, e = this->getNumArguments(); i != e; ++i) {
975  Type param = fnType.getInputs()[i];
976  auto inputPtrType = dyn_cast<spirv::PointerType>(param);
977  if (!inputPtrType)
978  continue;
979 
980  auto pointeePtrType =
981  dyn_cast<spirv::PointerType>(inputPtrType.getPointeeType());
982  if (pointeePtrType) {
983  // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
984  // > If an OpFunctionParameter is a pointer (or contains a pointer)
985  // > and the type it points to is a pointer in the PhysicalStorageBuffer
986  // > storage class, the function parameter must be decorated with exactly
987  // > one of AliasedPointer or RestrictPointer.
988  if (pointeePtrType.getStorageClass() !=
989  spirv::StorageClass::PhysicalStorageBuffer)
990  continue;
991 
992  bool hasAliasedPtr =
993  hasDecorationAttr(spirv::Decoration::AliasedPointer, i);
994  bool hasRestrictPtr =
995  hasDecorationAttr(spirv::Decoration::RestrictPointer, i);
996  if (!hasAliasedPtr && !hasRestrictPtr)
997  return emitOpError()
998  << "with a pointer points to a physical buffer pointer must "
999  "be decorated either 'AliasedPointer' or 'RestrictPointer'";
1000  continue;
1001  }
1002  // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
1003  // > If an OpFunctionParameter is a pointer (or contains a pointer) in
1004  // > the PhysicalStorageBuffer storage class, the function parameter must
1005  // > be decorated with exactly one of Aliased or Restrict.
1006  if (auto pointeeArrayType =
1007  dyn_cast<spirv::ArrayType>(inputPtrType.getPointeeType())) {
1008  pointeePtrType =
1009  dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
1010  } else {
1011  pointeePtrType = inputPtrType;
1012  }
1013 
1014  if (!pointeePtrType || pointeePtrType.getStorageClass() !=
1015  spirv::StorageClass::PhysicalStorageBuffer)
1016  continue;
1017 
1018  bool hasAliased = hasDecorationAttr(spirv::Decoration::Aliased, i);
1019  bool hasRestrict = hasDecorationAttr(spirv::Decoration::Restrict, i);
1020  if (!hasAliased && !hasRestrict)
1021  return emitOpError() << "with physical buffer pointer must be decorated "
1022  "either 'Aliased' or 'Restrict'";
1023  }
1024 
1025  return success();
1026 }
1027 
1028 LogicalResult spirv::FuncOp::verifyBody() {
1029  FunctionType fnType = getFunctionType();
1030  if (!isExternal()) {
1031  Block &entryBlock = front();
1032 
1033  unsigned numArguments = this->getNumArguments();
1034  if (entryBlock.getNumArguments() != numArguments)
1035  return emitOpError("entry block must have ")
1036  << numArguments << " arguments to match function signature";
1037 
1038  for (auto [index, fnArgType, blockArgType] :
1039  llvm::enumerate(getArgumentTypes(), entryBlock.getArgumentTypes())) {
1040  if (blockArgType != fnArgType) {
1041  return emitOpError("type of entry block argument #")
1042  << index << '(' << blockArgType
1043  << ") must match the type of the corresponding argument in "
1044  << "function signature(" << fnArgType << ')';
1045  }
1046  }
1047  }
1048 
1049  auto walkResult = walk([fnType](Operation *op) -> WalkResult {
1050  if (auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
1051  if (fnType.getNumResults() != 0)
1052  return retOp.emitOpError("cannot be used in functions returning value");
1053  } else if (auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
1054  if (fnType.getNumResults() != 1)
1055  return retOp.emitOpError(
1056  "returns 1 value but enclosing function requires ")
1057  << fnType.getNumResults() << " results";
1058 
1059  auto retOperandType = retOp.getValue().getType();
1060  auto fnResultType = fnType.getResult(0);
1061  if (retOperandType != fnResultType)
1062  return retOp.emitOpError(" return value's type (")
1063  << retOperandType << ") mismatch with function's result type ("
1064  << fnResultType << ")";
1065  }
1066  return WalkResult::advance();
1067  });
1068 
1069  // TODO: verify other bits like linkage type.
1070 
1071  return failure(walkResult.wasInterrupted());
1072 }
1073 
1074 void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
1075  StringRef name, FunctionType type,
1076  spirv::FunctionControl control,
1077  ArrayRef<NamedAttribute> attrs) {
1078  state.addAttribute(SymbolTable::getSymbolAttrName(),
1079  builder.getStringAttr(name));
1080  state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
1081  state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
1082  builder.getAttr<spirv::FunctionControlAttr>(control));
1083  state.attributes.append(attrs.begin(), attrs.end());
1084  state.addRegion();
1085 }
1086 
1087 //===----------------------------------------------------------------------===//
1088 // spirv.GLFClampOp
1089 //===----------------------------------------------------------------------===//
1090 
1091 ParseResult spirv::GLFClampOp::parse(OpAsmParser &parser,
1092  OperationState &result) {
1093  return parseOneResultSameOperandTypeOp(parser, result);
1094 }
1096 
1097 //===----------------------------------------------------------------------===//
1098 // spirv.GLUClampOp
1099 //===----------------------------------------------------------------------===//
1100 
1101 ParseResult spirv::GLUClampOp::parse(OpAsmParser &parser,
1102  OperationState &result) {
1103  return parseOneResultSameOperandTypeOp(parser, result);
1104 }
1106 
1107 //===----------------------------------------------------------------------===//
1108 // spirv.GLSClampOp
1109 //===----------------------------------------------------------------------===//
1110 
1111 ParseResult spirv::GLSClampOp::parse(OpAsmParser &parser,
1112  OperationState &result) {
1113  return parseOneResultSameOperandTypeOp(parser, result);
1114 }
1116 
1117 //===----------------------------------------------------------------------===//
1118 // spirv.GLFmaOp
1119 //===----------------------------------------------------------------------===//
1120 
1121 ParseResult spirv::GLFmaOp::parse(OpAsmParser &parser, OperationState &result) {
1122  return parseOneResultSameOperandTypeOp(parser, result);
1123 }
1124 void spirv::GLFmaOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1125 
1126 //===----------------------------------------------------------------------===//
1127 // spirv.GlobalVariable
1128 //===----------------------------------------------------------------------===//
1129 
1130 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
1131  Type type, StringRef name,
1132  unsigned descriptorSet, unsigned binding) {
1133  build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
1134  state.addAttribute(
1135  spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
1136  builder.getI32IntegerAttr(descriptorSet));
1137  state.addAttribute(
1138  spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
1139  builder.getI32IntegerAttr(binding));
1140 }
1141 
1142 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
1143  Type type, StringRef name,
1144  spirv::BuiltIn builtin) {
1145  build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
1146  state.addAttribute(
1147  spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
1148  builder.getStringAttr(spirv::stringifyBuiltIn(builtin)));
1149 }
1150 
1151 ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
1152  OperationState &result) {
1153  // Parse variable name.
1154  StringAttr nameAttr;
1155  StringRef initializerAttrName =
1156  spirv::GlobalVariableOp::getInitializerAttrName(result.name);
1157  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1158  result.attributes)) {
1159  return failure();
1160  }
1161 
1162  // Parse optional initializer
1163  if (succeeded(parser.parseOptionalKeyword(initializerAttrName))) {
1164  FlatSymbolRefAttr initSymbol;
1165  if (parser.parseLParen() ||
1166  parser.parseAttribute(initSymbol, Type(), initializerAttrName,
1167  result.attributes) ||
1168  parser.parseRParen())
1169  return failure();
1170  }
1171 
1172  if (parseVariableDecorations(parser, result)) {
1173  return failure();
1174  }
1175 
1176  Type type;
1177  StringRef typeAttrName =
1178  spirv::GlobalVariableOp::getTypeAttrName(result.name);
1179  auto loc = parser.getCurrentLocation();
1180  if (parser.parseColonType(type)) {
1181  return failure();
1182  }
1183  if (!llvm::isa<spirv::PointerType>(type)) {
1184  return parser.emitError(loc, "expected spirv.ptr type");
1185  }
1186  result.addAttribute(typeAttrName, TypeAttr::get(type));
1187 
1188  return success();
1189 }
1190 
1192  SmallVector<StringRef, 4> elidedAttrs{
1193  spirv::attributeName<spirv::StorageClass>()};
1194 
1195  // Print variable name.
1196  printer << ' ';
1197  printer.printSymbolName(getSymName());
1198  elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
1199 
1200  StringRef initializerAttrName = this->getInitializerAttrName();
1201  // Print optional initializer
1202  if (auto initializer = this->getInitializer()) {
1203  printer << " " << initializerAttrName << '(';
1204  printer.printSymbolName(*initializer);
1205  printer << ')';
1206  elidedAttrs.push_back(initializerAttrName);
1207  }
1208 
1209  StringRef typeAttrName = this->getTypeAttrName();
1210  elidedAttrs.push_back(typeAttrName);
1211  spirv::printVariableDecorations(*this, printer, elidedAttrs);
1212  printer << " : " << getType();
1213 }
1214 
1215 LogicalResult spirv::GlobalVariableOp::verify() {
1216  if (!llvm::isa<spirv::PointerType>(getType()))
1217  return emitOpError("result must be of a !spv.ptr type");
1218 
1219  // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
1220  // object. It cannot be Generic. It must be the same as the Storage Class
1221  // operand of the Result Type."
1222  // Also, Function storage class is reserved by spirv.Variable.
1223  auto storageClass = this->storageClass();
1224  if (storageClass == spirv::StorageClass::Generic ||
1225  storageClass == spirv::StorageClass::Function) {
1226  return emitOpError("storage class cannot be '")
1227  << stringifyStorageClass(storageClass) << "'";
1228  }
1229 
1230  if (auto init = (*this)->getAttrOfType<FlatSymbolRefAttr>(
1231  this->getInitializerAttrName())) {
1233  (*this)->getParentOp(), init.getAttr());
1234  // TODO: Currently only variable initialization with specialization
1235  // constants and other variables is supported. They could be normal
1236  // constants in the module scope as well.
1237  if (!initOp || !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp,
1238  spirv::SpecConstantCompositeOp>(initOp)) {
1239  return emitOpError("initializer must be result of a "
1240  "spirv.SpecConstant or spirv.GlobalVariable or "
1241  "spirv.SpecConstantCompositeOp op");
1242  }
1243  }
1244 
1245  return success();
1246 }
1247 
1248 //===----------------------------------------------------------------------===//
1249 // spirv.INTEL.SubgroupBlockRead
1250 //===----------------------------------------------------------------------===//
1251 
1253  if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
1254  return failure();
1255 
1256  return success();
1257 }
1258 
1259 //===----------------------------------------------------------------------===//
1260 // spirv.INTEL.SubgroupBlockWrite
1261 //===----------------------------------------------------------------------===//
1262 
1264  OperationState &result) {
1265  // Parse the storage class specification
1266  spirv::StorageClass storageClass;
1268  auto loc = parser.getCurrentLocation();
1269  Type elementType;
1270  if (parseEnumStrAttr(storageClass, parser) ||
1271  parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
1272  parser.parseType(elementType)) {
1273  return failure();
1274  }
1275 
1276  auto ptrType = spirv::PointerType::get(elementType, storageClass);
1277  if (auto valVecTy = llvm::dyn_cast<VectorType>(elementType))
1278  ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
1279 
1280  if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
1281  result.operands)) {
1282  return failure();
1283  }
1284  return success();
1285 }
1286 
1288  printer << " " << getPtr() << ", " << getValue() << " : "
1289  << getValue().getType();
1290 }
1291 
1293  if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
1294  return failure();
1295 
1296  return success();
1297 }
1298 
1299 //===----------------------------------------------------------------------===//
1300 // spirv.IAddCarryOp
1301 //===----------------------------------------------------------------------===//
1302 
1303 LogicalResult spirv::IAddCarryOp::verify() {
1305 }
1306 
1307 ParseResult spirv::IAddCarryOp::parse(OpAsmParser &parser,
1308  OperationState &result) {
1310 }
1311 
1312 void spirv::IAddCarryOp::print(OpAsmPrinter &printer) {
1313  ::printArithmeticExtendedBinaryOp(*this, printer);
1314 }
1315 
1316 //===----------------------------------------------------------------------===//
1317 // spirv.ISubBorrowOp
1318 //===----------------------------------------------------------------------===//
1319 
1320 LogicalResult spirv::ISubBorrowOp::verify() {
1322 }
1323 
1324 ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser,
1325  OperationState &result) {
1327 }
1328 
1330  ::printArithmeticExtendedBinaryOp(*this, printer);
1331 }
1332 
1333 //===----------------------------------------------------------------------===//
1334 // spirv.SMulExtended
1335 //===----------------------------------------------------------------------===//
1336 
1337 LogicalResult spirv::SMulExtendedOp::verify() {
1339 }
1340 
1341 ParseResult spirv::SMulExtendedOp::parse(OpAsmParser &parser,
1342  OperationState &result) {
1344 }
1345 
1347  ::printArithmeticExtendedBinaryOp(*this, printer);
1348 }
1349 
1350 //===----------------------------------------------------------------------===//
1351 // spirv.UMulExtended
1352 //===----------------------------------------------------------------------===//
1353 
1354 LogicalResult spirv::UMulExtendedOp::verify() {
1356 }
1357 
1358 ParseResult spirv::UMulExtendedOp::parse(OpAsmParser &parser,
1359  OperationState &result) {
1361 }
1362 
1364  ::printArithmeticExtendedBinaryOp(*this, printer);
1365 }
1366 
1367 //===----------------------------------------------------------------------===//
1368 // spirv.MemoryBarrierOp
1369 //===----------------------------------------------------------------------===//
1370 
1371 LogicalResult spirv::MemoryBarrierOp::verify() {
1372  return verifyMemorySemantics(getOperation(), getMemorySemantics());
1373 }
1374 
1375 //===----------------------------------------------------------------------===//
1376 // spirv.module
1377 //===----------------------------------------------------------------------===//
1378 
1379 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
1380  std::optional<StringRef> name) {
1381  OpBuilder::InsertionGuard guard(builder);
1382  builder.createBlock(state.addRegion());
1383  if (name) {
1384  state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
1385  builder.getStringAttr(*name));
1386  }
1387 }
1388 
1389 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
1390  spirv::AddressingModel addressingModel,
1391  spirv::MemoryModel memoryModel,
1392  std::optional<VerCapExtAttr> vceTriple,
1393  std::optional<StringRef> name) {
1394  state.addAttribute(
1395  "addressing_model",
1396  builder.getAttr<spirv::AddressingModelAttr>(addressingModel));
1397  state.addAttribute("memory_model",
1398  builder.getAttr<spirv::MemoryModelAttr>(memoryModel));
1399  OpBuilder::InsertionGuard guard(builder);
1400  builder.createBlock(state.addRegion());
1401  if (vceTriple)
1402  state.addAttribute(getVCETripleAttrName(), *vceTriple);
1403  if (name)
1404  state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
1405  builder.getStringAttr(*name));
1406 }
1407 
1408 ParseResult spirv::ModuleOp::parse(OpAsmParser &parser,
1409  OperationState &result) {
1410  Region *body = result.addRegion();
1411 
1412  // If the name is present, parse it.
1413  StringAttr nameAttr;
1414  (void)parser.parseOptionalSymbolName(
1415  nameAttr, mlir::SymbolTable::getSymbolAttrName(), result.attributes);
1416 
1417  // Parse attributes
1418  spirv::AddressingModel addrModel;
1419  spirv::MemoryModel memoryModel;
1420  if (spirv::parseEnumKeywordAttr<spirv::AddressingModelAttr>(addrModel, parser,
1421  result) ||
1422  spirv::parseEnumKeywordAttr<spirv::MemoryModelAttr>(memoryModel, parser,
1423  result))
1424  return failure();
1425 
1426  if (succeeded(parser.parseOptionalKeyword("requires"))) {
1427  spirv::VerCapExtAttr vceTriple;
1428  if (parser.parseAttribute(vceTriple,
1429  spirv::ModuleOp::getVCETripleAttrName(),
1430  result.attributes))
1431  return failure();
1432  }
1433 
1434  if (parser.parseOptionalAttrDictWithKeyword(result.attributes) ||
1435  parser.parseRegion(*body, /*arguments=*/{}))
1436  return failure();
1437 
1438  // Make sure we have at least one block.
1439  if (body->empty())
1440  body->push_back(new Block());
1441 
1442  return success();
1443 }
1444 
1445 void spirv::ModuleOp::print(OpAsmPrinter &printer) {
1446  if (std::optional<StringRef> name = getName()) {
1447  printer << ' ';
1448  printer.printSymbolName(*name);
1449  }
1450 
1451  SmallVector<StringRef, 2> elidedAttrs;
1452 
1453  printer << " " << spirv::stringifyAddressingModel(getAddressingModel()) << " "
1454  << spirv::stringifyMemoryModel(getMemoryModel());
1455  auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
1456  auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
1457  elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
1459 
1460  if (std::optional<spirv::VerCapExtAttr> triple = getVceTriple()) {
1461  printer << " requires " << *triple;
1462  elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
1463  }
1464 
1465  printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs);
1466  printer << ' ';
1467  printer.printRegion(getRegion());
1468 }
1469 
1470 LogicalResult spirv::ModuleOp::verifyRegions() {
1471  Dialect *dialect = (*this)->getDialect();
1473  entryPoints;
1474  mlir::SymbolTable table(*this);
1475 
1476  for (auto &op : *getBody()) {
1477  if (op.getDialect() != dialect)
1478  return op.emitError("'spirv.module' can only contain spirv.* ops");
1479 
1480  // For EntryPoint op, check that the function and execution model is not
1481  // duplicated in EntryPointOps. Also verify that the interface specified
1482  // comes from globalVariables here to make this check cheaper.
1483  if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
1484  auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.getFn());
1485  if (!funcOp) {
1486  return entryPointOp.emitError("function '")
1487  << entryPointOp.getFn() << "' not found in 'spirv.module'";
1488  }
1489  if (auto interface = entryPointOp.getInterface()) {
1490  for (Attribute varRef : interface) {
1491  auto varSymRef = llvm::dyn_cast<FlatSymbolRefAttr>(varRef);
1492  if (!varSymRef) {
1493  return entryPointOp.emitError(
1494  "expected symbol reference for interface "
1495  "specification instead of '")
1496  << varRef;
1497  }
1498  auto variableOp =
1499  table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
1500  if (!variableOp) {
1501  return entryPointOp.emitError("expected spirv.GlobalVariable "
1502  "symbol reference instead of'")
1503  << varSymRef << "'";
1504  }
1505  }
1506  }
1507 
1508  auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
1509  funcOp, entryPointOp.getExecutionModel());
1510  if (!entryPoints.try_emplace(key, entryPointOp).second)
1511  return entryPointOp.emitError("duplicate of a previous EntryPointOp");
1512  } else if (auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
1513  // If the function is external and does not have 'Import'
1514  // linkage_attributes(LinkageAttributes), throw an error. 'Import'
1515  // LinkageAttributes is used to import external functions.
1516  auto linkageAttr = funcOp.getLinkageAttributes();
1517  auto hasImportLinkage =
1518  linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
1519  spirv::LinkageType::Import);
1520  if (funcOp.isExternal() && !hasImportLinkage)
1521  return op.emitError(
1522  "'spirv.module' cannot contain external functions "
1523  "without 'Import' linkage_attributes (LinkageAttributes)");
1524 
1525  // TODO: move this check to spirv.func.
1526  for (auto &block : funcOp)
1527  for (auto &op : block) {
1528  if (op.getDialect() != dialect)
1529  return op.emitError(
1530  "functions in 'spirv.module' can only contain spirv.* ops");
1531  }
1532  }
1533  }
1534 
1535  return success();
1536 }
1537 
1538 //===----------------------------------------------------------------------===//
1539 // spirv.mlir.referenceof
1540 //===----------------------------------------------------------------------===//
1541 
1542 LogicalResult spirv::ReferenceOfOp::verify() {
1543  auto *specConstSym = SymbolTable::lookupNearestSymbolFrom(
1544  (*this)->getParentOp(), getSpecConstAttr());
1545  Type constType;
1546 
1547  auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
1548  if (specConstOp)
1549  constType = specConstOp.getDefaultValue().getType();
1550 
1551  auto specConstCompositeOp =
1552  dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
1553  if (specConstCompositeOp)
1554  constType = specConstCompositeOp.getType();
1555 
1556  if (!specConstOp && !specConstCompositeOp)
1557  return emitOpError(
1558  "expected spirv.SpecConstant or spirv.SpecConstantComposite symbol");
1559 
1560  if (getReference().getType() != constType)
1561  return emitOpError("result type mismatch with the referenced "
1562  "specialization constant's type");
1563 
1564  return success();
1565 }
1566 
1567 //===----------------------------------------------------------------------===//
1568 // spirv.SpecConstant
1569 //===----------------------------------------------------------------------===//
1570 
1571 ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
1572  OperationState &result) {
1573  StringAttr nameAttr;
1574  Attribute valueAttr;
1575  StringRef defaultValueAttrName =
1576  spirv::SpecConstantOp::getDefaultValueAttrName(result.name);
1577 
1578  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1579  result.attributes))
1580  return failure();
1581 
1582  // Parse optional spec_id.
1583  if (succeeded(parser.parseOptionalKeyword(kSpecIdAttrName))) {
1584  IntegerAttr specIdAttr;
1585  if (parser.parseLParen() ||
1586  parser.parseAttribute(specIdAttr, kSpecIdAttrName, result.attributes) ||
1587  parser.parseRParen())
1588  return failure();
1589  }
1590 
1591  if (parser.parseEqual() ||
1592  parser.parseAttribute(valueAttr, defaultValueAttrName, result.attributes))
1593  return failure();
1594 
1595  return success();
1596 }
1597 
1599  printer << ' ';
1600  printer.printSymbolName(getSymName());
1601  if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
1602  printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')';
1603  printer << " = " << getDefaultValue();
1604 }
1605 
1606 LogicalResult spirv::SpecConstantOp::verify() {
1607  if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
1608  if (specID.getValue().isNegative())
1609  return emitOpError("SpecId cannot be negative");
1610 
1611  auto value = getDefaultValue();
1612  if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
1613  // Make sure bitwidth is allowed.
1614  if (!llvm::isa<spirv::SPIRVType>(value.getType()))
1615  return emitOpError("default value bitwidth disallowed");
1616  return success();
1617  }
1618  return emitOpError(
1619  "default value can only be a bool, integer, or float scalar");
1620 }
1621 
1622 //===----------------------------------------------------------------------===//
1623 // spirv.VectorShuffle
1624 //===----------------------------------------------------------------------===//
1625 
1626 LogicalResult spirv::VectorShuffleOp::verify() {
1627  VectorType resultType = llvm::cast<VectorType>(getType());
1628 
1629  size_t numResultElements = resultType.getNumElements();
1630  if (numResultElements != getComponents().size())
1631  return emitOpError("result type element count (")
1632  << numResultElements
1633  << ") mismatch with the number of component selectors ("
1634  << getComponents().size() << ")";
1635 
1636  size_t totalSrcElements =
1637  llvm::cast<VectorType>(getVector1().getType()).getNumElements() +
1638  llvm::cast<VectorType>(getVector2().getType()).getNumElements();
1639 
1640  for (const auto &selector : getComponents().getAsValueRange<IntegerAttr>()) {
1641  uint32_t index = selector.getZExtValue();
1642  if (index >= totalSrcElements &&
1643  index != std::numeric_limits<uint32_t>().max())
1644  return emitOpError("component selector ")
1645  << index << " out of range: expected to be in [0, "
1646  << totalSrcElements << ") or 0xffffffff";
1647  }
1648  return success();
1649 }
1650 
1651 //===----------------------------------------------------------------------===//
1652 // spirv.MatrixTimesScalar
1653 //===----------------------------------------------------------------------===//
1654 
1655 LogicalResult spirv::MatrixTimesScalarOp::verify() {
1656  Type elementType =
1657  llvm::TypeSwitch<Type, Type>(getMatrix().getType())
1659  [](auto matrixType) { return matrixType.getElementType(); })
1660  .Default([](Type) { return nullptr; });
1661 
1662  assert(elementType && "Unhandled type");
1663 
1664  // Check that the scalar type is the same as the matrix element type.
1665  if (getScalar().getType() != elementType)
1666  return emitOpError("input matrix components' type and scaling value must "
1667  "have the same type");
1668 
1669  return success();
1670 }
1671 
1672 //===----------------------------------------------------------------------===//
1673 // spirv.Transpose
1674 //===----------------------------------------------------------------------===//
1675 
1676 LogicalResult spirv::TransposeOp::verify() {
1677  auto inputMatrix = llvm::cast<spirv::MatrixType>(getMatrix().getType());
1678  auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().getType());
1679 
1680  // Verify that the input and output matrices have correct shapes.
1681  if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
1682  return emitError("input matrix rows count must be equal to "
1683  "output matrix columns count");
1684 
1685  if (inputMatrix.getNumColumns() != resultMatrix.getNumRows())
1686  return emitError("input matrix columns count must be equal to "
1687  "output matrix rows count");
1688 
1689  // Verify that the input and output matrices have the same component type
1690  if (inputMatrix.getElementType() != resultMatrix.getElementType())
1691  return emitError("input and output matrices must have the same "
1692  "component type");
1693 
1694  return success();
1695 }
1696 
1697 //===----------------------------------------------------------------------===//
1698 // spirv.MatrixTimesVector
1699 //===----------------------------------------------------------------------===//
1700 
1701 LogicalResult spirv::MatrixTimesVectorOp::verify() {
1702  auto matrixType = llvm::cast<spirv::MatrixType>(getMatrix().getType());
1703  auto vectorType = llvm::cast<VectorType>(getVector().getType());
1704  auto resultType = llvm::cast<VectorType>(getType());
1705 
1706  if (matrixType.getNumColumns() != vectorType.getNumElements())
1707  return emitOpError("matrix columns (")
1708  << matrixType.getNumColumns() << ") must match vector operand size ("
1709  << vectorType.getNumElements() << ")";
1710 
1711  if (resultType.getNumElements() != matrixType.getNumRows())
1712  return emitOpError("result size (")
1713  << resultType.getNumElements() << ") must match the matrix rows ("
1714  << matrixType.getNumRows() << ")";
1715 
1716  if (matrixType.getElementType() != resultType.getElementType())
1717  return emitOpError("matrix and result element types must match");
1718 
1719  return success();
1720 }
1721 
1722 //===----------------------------------------------------------------------===//
1723 // spirv.VectorTimesMatrix
1724 //===----------------------------------------------------------------------===//
1725 
1726 LogicalResult spirv::VectorTimesMatrixOp::verify() {
1727  auto vectorType = llvm::cast<VectorType>(getVector().getType());
1728  auto matrixType = llvm::cast<spirv::MatrixType>(getMatrix().getType());
1729  auto resultType = llvm::cast<VectorType>(getType());
1730 
1731  if (matrixType.getNumRows() != vectorType.getNumElements())
1732  return emitOpError("number of components in vector must equal the number "
1733  "of components in each column in matrix");
1734 
1735  if (resultType.getNumElements() != matrixType.getNumColumns())
1736  return emitOpError("number of columns in matrix must equal the number of "
1737  "components in result");
1738 
1739  if (matrixType.getElementType() != resultType.getElementType())
1740  return emitOpError("matrix must be a matrix with the same component type "
1741  "as the component type in result");
1742 
1743  return success();
1744 }
1745 
1746 //===----------------------------------------------------------------------===//
1747 // spirv.MatrixTimesMatrix
1748 //===----------------------------------------------------------------------===//
1749 
1750 LogicalResult spirv::MatrixTimesMatrixOp::verify() {
1751  auto leftMatrix = llvm::cast<spirv::MatrixType>(getLeftmatrix().getType());
1752  auto rightMatrix = llvm::cast<spirv::MatrixType>(getRightmatrix().getType());
1753  auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().getType());
1754 
1755  // left matrix columns' count and right matrix rows' count must be equal
1756  if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
1757  return emitError("left matrix columns' count must be equal to "
1758  "the right matrix rows' count");
1759 
1760  // right and result matrices columns' count must be the same
1761  if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns())
1762  return emitError(
1763  "right and result matrices must have equal columns' count");
1764 
1765  // right and result matrices component type must be the same
1766  if (rightMatrix.getElementType() != resultMatrix.getElementType())
1767  return emitError("right and result matrices' component type must"
1768  " be the same");
1769 
1770  // left and result matrices component type must be the same
1771  if (leftMatrix.getElementType() != resultMatrix.getElementType())
1772  return emitError("left and result matrices' component type"
1773  " must be the same");
1774 
1775  // left and result matrices rows count must be the same
1776  if (leftMatrix.getNumRows() != resultMatrix.getNumRows())
1777  return emitError("left and result matrices must have equal rows' count");
1778 
1779  return success();
1780 }
1781 
1782 //===----------------------------------------------------------------------===//
1783 // spirv.SpecConstantComposite
1784 //===----------------------------------------------------------------------===//
1785 
1787  OperationState &result) {
1788 
1789  StringAttr compositeName;
1790  if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
1791  result.attributes))
1792  return failure();
1793 
1794  if (parser.parseLParen())
1795  return failure();
1796 
1797  SmallVector<Attribute, 4> constituents;
1798 
1799  do {
1800  // The name of the constituent attribute isn't important
1801  const char *attrName = "spec_const";
1802  FlatSymbolRefAttr specConstRef;
1803  NamedAttrList attrs;
1804 
1805  if (parser.parseAttribute(specConstRef, Type(), attrName, attrs))
1806  return failure();
1807 
1808  constituents.push_back(specConstRef);
1809  } while (!parser.parseOptionalComma());
1810 
1811  if (parser.parseRParen())
1812  return failure();
1813 
1814  StringAttr compositeSpecConstituentsName =
1815  spirv::SpecConstantCompositeOp::getConstituentsAttrName(result.name);
1816  result.addAttribute(compositeSpecConstituentsName,
1817  parser.getBuilder().getArrayAttr(constituents));
1818 
1819  Type type;
1820  if (parser.parseColonType(type))
1821  return failure();
1822 
1823  StringAttr typeAttrName =
1824  spirv::SpecConstantCompositeOp::getTypeAttrName(result.name);
1825  result.addAttribute(typeAttrName, TypeAttr::get(type));
1826 
1827  return success();
1828 }
1829 
1831  printer << " ";
1832  printer.printSymbolName(getSymName());
1833  printer << " (" << llvm::interleaved(this->getConstituents().getValue())
1834  << ") : " << getType();
1835 }
1836 
1838  auto cType = llvm::dyn_cast<spirv::CompositeType>(getType());
1839  auto constituents = this->getConstituents().getValue();
1840 
1841  if (!cType)
1842  return emitError("result type must be a composite type, but provided ")
1843  << getType();
1844 
1845  if (llvm::isa<spirv::CooperativeMatrixType>(cType))
1846  return emitError("unsupported composite type ") << cType;
1847  if (constituents.size() != cType.getNumElements())
1848  return emitError("has incorrect number of operands: expected ")
1849  << cType.getNumElements() << ", but provided "
1850  << constituents.size();
1851 
1852  for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
1853  auto constituent = llvm::cast<FlatSymbolRefAttr>(constituents[index]);
1854 
1855  auto constituentSpecConstOp =
1856  dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom(
1857  (*this)->getParentOp(), constituent.getAttr()));
1858 
1859  if (constituentSpecConstOp.getDefaultValue().getType() !=
1860  cType.getElementType(index))
1861  return emitError("has incorrect types of operands: expected ")
1862  << cType.getElementType(index) << ", but provided "
1863  << constituentSpecConstOp.getDefaultValue().getType();
1864  }
1865 
1866  return success();
1867 }
1868 
1869 //===----------------------------------------------------------------------===//
1870 // spirv.SpecConstantOperation
1871 //===----------------------------------------------------------------------===//
1872 
1874  OperationState &result) {
1875  Region *body = result.addRegion();
1876 
1877  if (parser.parseKeyword("wraps"))
1878  return failure();
1879 
1880  body->push_back(new Block);
1881  Block &block = body->back();
1882  Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
1883 
1884  if (!wrappedOp)
1885  return failure();
1886 
1887  OpBuilder builder(parser.getContext());
1888  builder.setInsertionPointToEnd(&block);
1889  builder.create<spirv::YieldOp>(wrappedOp->getLoc(), wrappedOp->getResult(0));
1890  result.location = wrappedOp->getLoc();
1891 
1892  result.addTypes(wrappedOp->getResult(0).getType());
1893 
1894  if (parser.parseOptionalAttrDict(result.attributes))
1895  return failure();
1896 
1897  return success();
1898 }
1899 
1901  printer << " wraps ";
1902  printer.printGenericOp(&getBody().front().front());
1903 }
1904 
1905 LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
1906  Block &block = getRegion().getBlocks().front();
1907 
1908  if (block.getOperations().size() != 2)
1909  return emitOpError("expected exactly 2 nested ops");
1910 
1911  Operation &enclosedOp = block.getOperations().front();
1912 
1914  return emitOpError("invalid enclosed op");
1915 
1916  for (auto operand : enclosedOp.getOperands())
1917  if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
1918  spirv::SpecConstantOperationOp>(operand.getDefiningOp()))
1919  return emitOpError(
1920  "invalid operand, must be defined by a constant operation");
1921 
1922  return success();
1923 }
1924 
1925 //===----------------------------------------------------------------------===//
1926 // spirv.GL.FrexpStruct
1927 //===----------------------------------------------------------------------===//
1928 
1929 LogicalResult spirv::GLFrexpStructOp::verify() {
1930  spirv::StructType structTy =
1931  llvm::dyn_cast<spirv::StructType>(getResult().getType());
1932 
1933  if (structTy.getNumElements() != 2)
1934  return emitError("result type must be a struct type with two memebers");
1935 
1936  Type significandTy = structTy.getElementType(0);
1937  Type exponentTy = structTy.getElementType(1);
1938  VectorType exponentVecTy = llvm::dyn_cast<VectorType>(exponentTy);
1939  IntegerType exponentIntTy = llvm::dyn_cast<IntegerType>(exponentTy);
1940 
1941  Type operandTy = getOperand().getType();
1942  VectorType operandVecTy = llvm::dyn_cast<VectorType>(operandTy);
1943  FloatType operandFTy = llvm::dyn_cast<FloatType>(operandTy);
1944 
1945  if (significandTy != operandTy)
1946  return emitError("member zero of the resulting struct type must be the "
1947  "same type as the operand");
1948 
1949  if (exponentVecTy) {
1950  IntegerType componentIntTy =
1951  llvm::dyn_cast<IntegerType>(exponentVecTy.getElementType());
1952  if (!componentIntTy || componentIntTy.getWidth() != 32)
1953  return emitError("member one of the resulting struct type must"
1954  "be a scalar or vector of 32 bit integer type");
1955  } else if (!exponentIntTy || exponentIntTy.getWidth() != 32) {
1956  return emitError("member one of the resulting struct type "
1957  "must be a scalar or vector of 32 bit integer type");
1958  }
1959 
1960  // Check that the two member types have the same number of components
1961  if (operandVecTy && exponentVecTy &&
1962  (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
1963  return success();
1964 
1965  if (operandFTy && exponentIntTy)
1966  return success();
1967 
1968  return emitError("member one of the resulting struct type must have the same "
1969  "number of components as the operand type");
1970 }
1971 
1972 //===----------------------------------------------------------------------===//
1973 // spirv.GL.Ldexp
1974 //===----------------------------------------------------------------------===//
1975 
1976 LogicalResult spirv::GLLdexpOp::verify() {
1977  Type significandType = getX().getType();
1978  Type exponentType = getExp().getType();
1979 
1980  if (llvm::isa<FloatType>(significandType) !=
1981  llvm::isa<IntegerType>(exponentType))
1982  return emitOpError("operands must both be scalars or vectors");
1983 
1984  auto getNumElements = [](Type type) -> unsigned {
1985  if (auto vectorType = llvm::dyn_cast<VectorType>(type))
1986  return vectorType.getNumElements();
1987  return 1;
1988  };
1989 
1990  if (getNumElements(significandType) != getNumElements(exponentType))
1991  return emitOpError("operands must have the same number of elements");
1992 
1993  return success();
1994 }
1995 
1996 //===----------------------------------------------------------------------===//
1997 // spirv.ShiftLeftLogicalOp
1998 //===----------------------------------------------------------------------===//
1999 
2000 LogicalResult spirv::ShiftLeftLogicalOp::verify() {
2001  return verifyShiftOp(*this);
2002 }
2003 
2004 //===----------------------------------------------------------------------===//
2005 // spirv.ShiftRightArithmeticOp
2006 //===----------------------------------------------------------------------===//
2007 
2008 LogicalResult spirv::ShiftRightArithmeticOp::verify() {
2009  return verifyShiftOp(*this);
2010 }
2011 
2012 //===----------------------------------------------------------------------===//
2013 // spirv.ShiftRightLogicalOp
2014 //===----------------------------------------------------------------------===//
2015 
2016 LogicalResult spirv::ShiftRightLogicalOp::verify() {
2017  return verifyShiftOp(*this);
2018 }
2019 
2020 //===----------------------------------------------------------------------===//
2021 // spirv.VectorTimesScalarOp
2022 //===----------------------------------------------------------------------===//
2023 
2024 LogicalResult spirv::VectorTimesScalarOp::verify() {
2025  if (getVector().getType() != getType())
2026  return emitOpError("vector operand and result type mismatch");
2027  auto scalarType = llvm::cast<VectorType>(getType()).getElementType();
2028  if (getScalar().getType() != scalarType)
2029  return emitOpError("scalar operand and result element type match");
2030  return success();
2031 }
static std::string bindingName()
Returns the string name of the Binding decoration.
static std::string descriptorSetName()
Returns the string name of the DescriptorSet decoration.
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser, OperationState &result)
Definition: SPIRVOps.cpp:269
static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, Type opType)
Definition: SPIRVOps.cpp:565
static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, OperationState &result)
Definition: SPIRVOps.cpp:122
static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op)
Definition: SPIRVOps.cpp:255
static LogicalResult verifyShiftOp(Operation *op)
Definition: SPIRVOps.cpp:301
static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op, Value ptr, Value val)
Definition: SPIRVOps.cpp:171
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:188
static void printOneResultOp(Operation *op, OpAsmPrinter &p)
Definition: SPIRVOps.cpp:151
static void printArithmeticExtendedBinaryOp(Operation *op, OpAsmPrinter &printer)
Definition: SPIRVOps.cpp:293
const float * table
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseOptionalSymbolName(StringAttr &result)=0
Parse an optional -identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
ParseResult addTypesToList(ArrayRef< Type > types, SmallVectorImpl< Type > &result)
Add the specified types to the end of the specified type list and return success.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:149
unsigned getNumArguments()
Definition: Block.h:128
OpListType & getOperations()
Definition: Block.h:137
Operation & front()
Definition: Block.h:153
iterator begin()
Definition: Block.h:143
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:195
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:223
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:271
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:249
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:75
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:66
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:95
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:257
MLIRContext * getContext() const
Definition: Builders.h:55
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:261
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition: Builders.h:96
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual Operation * parseGenericOperation(Block *insertBlock, Block::iterator insertPt)=0
Parse an operation in its generic form.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printGenericOp(Operation *op, bool printOpName=true)=0
Print the entire operation with the default generic assembly form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:425
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:434
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:452
A trait to mark ops that can be enclosed/wrapped in a SpecConstantOperation op.
Definition: SPIRVOpTraits.h:33
type_range getType() const
Definition: ValueRange.cpp:32
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:550
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
void push_back(Block *block)
Definition: Region.h:61
bool empty()
Definition: Region.h:60
Block & back()
Definition: Region.h:64
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:107
Type front()
Return first type in the range.
Definition: TypeRange.h:152
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: WalkResult.h:29
static WalkResult advance()
Definition: WalkResult.h:47
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:450
SPIR-V struct type.
Definition: SPIRVTypes.h:295
unsigned getNumElements() const
Type getElementType(unsigned) const
An attribute that specifies the SPIR-V (version, capabilities, extensions) triple.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:102
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the argument at 'index'.
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:22
constexpr char kFnNameAttrName[]
constexpr char kSpecIdAttrName[]
LogicalResult verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics)
Definition: SPIRVOps.cpp:71
ParseResult parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next string attribute in parser as an enumerant of the given EnumClass.
void printVariableDecorations(Operation *op, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs)
Definition: SPIRVOps.cpp:95
AddressingModel getAddressingModel(TargetEnvAttr targetAttr, bool use64bitAddress)
Returns addressing model selected based on target environment.
FailureOr< ExecutionModel > getExecutionModel(TargetEnvAttr targetAttr)
Returns execution model selected based on target environment.
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
LogicalResult extractValueFromConstOp(Operation *op, int32_t &value)
Definition: SPIRVOps.cpp:51
ParseResult parseVariableDecorations(OpAsmParser &parser, OperationState &state)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.