MLIR  21.0.0git
SPIRVOps.cpp
Go to the documentation of this file.
1 //===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the operations in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
15 #include "SPIRVOpUtils.h"
16 #include "SPIRVParsingUtils.h"
17 
24 #include "mlir/IR/Builders.h"
25 #include "mlir/IR/BuiltinTypes.h"
26 #include "mlir/IR/Matchers.h"
27 #include "mlir/IR/OpDefinition.h"
29 #include "mlir/IR/Operation.h"
30 #include "mlir/IR/TypeUtilities.h"
32 #include "llvm/ADT/APFloat.h"
33 #include "llvm/ADT/APInt.h"
34 #include "llvm/ADT/ArrayRef.h"
35 #include "llvm/ADT/STLExtras.h"
36 #include "llvm/ADT/StringExtras.h"
37 #include "llvm/ADT/TypeSwitch.h"
38 #include "llvm/Support/InterleavedRange.h"
39 #include <cassert>
40 #include <numeric>
41 #include <optional>
42 #include <type_traits>
43 
44 using namespace mlir;
45 using namespace mlir::spirv::AttrNames;
46 
47 //===----------------------------------------------------------------------===//
48 // Common utility functions
49 //===----------------------------------------------------------------------===//
50 
51 LogicalResult spirv::extractValueFromConstOp(Operation *op, int32_t &value) {
52  auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
53  if (!constOp) {
54  return failure();
55  }
56  auto valueAttr = constOp.getValue();
57  auto integerValueAttr = llvm::dyn_cast<IntegerAttr>(valueAttr);
58  if (!integerValueAttr) {
59  return failure();
60  }
61 
62  if (integerValueAttr.getType().isSignlessInteger())
63  value = integerValueAttr.getInt();
64  else
65  value = integerValueAttr.getSInt();
66 
67  return success();
68 }
69 
70 LogicalResult
72  spirv::MemorySemantics memorySemantics) {
73  // According to the SPIR-V specification:
74  // "Despite being a mask and allowing multiple bits to be combined, it is
75  // invalid for more than one of these four bits to be set: Acquire, Release,
76  // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and
77  // Release semantics is done by setting the AcquireRelease bit, not by setting
78  // two bits."
79  auto atMostOneInSet = spirv::MemorySemantics::Acquire |
80  spirv::MemorySemantics::Release |
81  spirv::MemorySemantics::AcquireRelease |
82  spirv::MemorySemantics::SequentiallyConsistent;
83 
84  auto bitCount =
85  llvm::popcount(static_cast<uint32_t>(memorySemantics & atMostOneInSet));
86  if (bitCount > 1) {
87  return op->emitError(
88  "expected at most one of these four memory constraints "
89  "to be set: `Acquire`, `Release`,"
90  "`AcquireRelease` or `SequentiallyConsistent`");
91  }
92  return success();
93 }
94 
96  SmallVectorImpl<StringRef> &elidedAttrs) {
97  // Print optional descriptor binding
98  auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
99  stringifyDecoration(spirv::Decoration::DescriptorSet));
100  auto bindingName = llvm::convertToSnakeFromCamelCase(
101  stringifyDecoration(spirv::Decoration::Binding));
102  auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
103  auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
104  if (descriptorSet && binding) {
105  elidedAttrs.push_back(descriptorSetName);
106  elidedAttrs.push_back(bindingName);
107  printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
108  << ")";
109  }
110 
111  // Print BuiltIn attribute if present
112  auto builtInName = llvm::convertToSnakeFromCamelCase(
113  stringifyDecoration(spirv::Decoration::BuiltIn));
114  if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
115  printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
116  elidedAttrs.push_back(builtInName);
117  }
118 
119  printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
120 }
121 
122 static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
123  OperationState &result) {
125  Type type;
126  // If the operand list is in-between parentheses, then we have a generic form.
127  // (see the fallback in `printOneResultOp`).
128  SMLoc loc = parser.getCurrentLocation();
129  if (!parser.parseOptionalLParen()) {
130  if (parser.parseOperandList(ops) || parser.parseRParen() ||
131  parser.parseOptionalAttrDict(result.attributes) ||
132  parser.parseColon() || parser.parseType(type))
133  return failure();
134  auto fnType = llvm::dyn_cast<FunctionType>(type);
135  if (!fnType) {
136  parser.emitError(loc, "expected function type");
137  return failure();
138  }
139  if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands))
140  return failure();
141  result.addTypes(fnType.getResults());
142  return success();
143  }
144  return failure(parser.parseOperandList(ops) ||
145  parser.parseOptionalAttrDict(result.attributes) ||
146  parser.parseColonType(type) ||
147  parser.resolveOperands(ops, type, result.operands) ||
148  parser.addTypeToList(type, result.types));
149 }
150 
152  assert(op->getNumResults() == 1 && "op should have one result");
153 
154  // If not all the operand and result types are the same, just use the
155  // generic assembly form to avoid omitting information in printing.
156  auto resultType = op->getResult(0).getType();
157  if (llvm::any_of(op->getOperandTypes(),
158  [&](Type type) { return type != resultType; })) {
159  p.printGenericOp(op, /*printOpName=*/false);
160  return;
161  }
162 
163  p << ' ';
164  p.printOperands(op->getOperands());
166  // Now we can output only one type for all operands and the result.
167  p << " : " << resultType;
168 }
169 
170 template <typename BlockReadWriteOpTy>
171 static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
172  Value ptr, Value val) {
173  auto valType = val.getType();
174  if (auto valVecTy = llvm::dyn_cast<VectorType>(valType))
175  valType = valVecTy.getElementType();
176 
177  if (valType !=
178  llvm::cast<spirv::PointerType>(ptr.getType()).getPointeeType()) {
179  return op.emitOpError("mismatch in result type and pointer type");
180  }
181  return success();
182 }
183 
184 /// Walks the given type hierarchy with the given indices, potentially down
185 /// to component granularity, to select an element type. Returns null type and
186 /// emits errors with the given loc on failure.
187 static Type
189  function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
190  if (indices.empty()) {
191  emitErrorFn("expected at least one index for spirv.CompositeExtract");
192  return nullptr;
193  }
194 
195  for (auto index : indices) {
196  if (auto cType = llvm::dyn_cast<spirv::CompositeType>(type)) {
197  if (cType.hasCompileTimeKnownNumElements() &&
198  (index < 0 ||
199  static_cast<uint64_t>(index) >= cType.getNumElements())) {
200  emitErrorFn("index ") << index << " out of bounds for " << type;
201  return nullptr;
202  }
203  type = cType.getElementType(index);
204  } else {
205  emitErrorFn("cannot extract from non-composite type ")
206  << type << " with index " << index;
207  return nullptr;
208  }
209  }
210  return type;
211 }
212 
213 static Type
215  function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
216  auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(indices);
217  if (!indicesArrayAttr) {
218  emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
219  return nullptr;
220  }
221  if (indicesArrayAttr.empty()) {
222  emitErrorFn("expected at least one index for spirv.CompositeExtract");
223  return nullptr;
224  }
225 
226  SmallVector<int32_t, 2> indexVals;
227  for (auto indexAttr : indicesArrayAttr) {
228  auto indexIntAttr = llvm::dyn_cast<IntegerAttr>(indexAttr);
229  if (!indexIntAttr) {
230  emitErrorFn("expected an 32-bit integer for index, but found '")
231  << indexAttr << "'";
232  return nullptr;
233  }
234  indexVals.push_back(indexIntAttr.getInt());
235  }
236  return getElementType(type, indexVals, emitErrorFn);
237 }
238 
239 static Type getElementType(Type type, Attribute indices, Location loc) {
240  auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
241  return ::mlir::emitError(loc, err);
242  };
243  return getElementType(type, indices, errorFn);
244 }
245 
246 static Type getElementType(Type type, Attribute indices, OpAsmParser &parser,
247  SMLoc loc) {
248  auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
249  return parser.emitError(loc, err);
250  };
251  return getElementType(type, indices, errorFn);
252 }
253 
254 template <typename ExtendedBinaryOp>
255 static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op) {
256  auto resultType = llvm::cast<spirv::StructType>(op.getType());
257  if (resultType.getNumElements() != 2)
258  return op.emitOpError("expected result struct type containing two members");
259 
260  if (!llvm::all_equal({op.getOperand1().getType(), op.getOperand2().getType(),
261  resultType.getElementType(0),
262  resultType.getElementType(1)}))
263  return op.emitOpError(
264  "expected all operand types and struct member types are the same");
265 
266  return success();
267 }
268 
269 static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser,
270  OperationState &result) {
272  if (parser.parseOptionalAttrDict(result.attributes) ||
273  parser.parseOperandList(operands) || parser.parseColon())
274  return failure();
275 
276  Type resultType;
277  SMLoc loc = parser.getCurrentLocation();
278  if (parser.parseType(resultType))
279  return failure();
280 
281  auto structType = llvm::dyn_cast<spirv::StructType>(resultType);
282  if (!structType || structType.getNumElements() != 2)
283  return parser.emitError(loc, "expected spirv.struct type with two members");
284 
285  SmallVector<Type, 2> operandTypes(2, structType.getElementType(0));
286  if (parser.resolveOperands(operands, operandTypes, loc, result.operands))
287  return failure();
288 
289  result.addTypes(resultType);
290  return success();
291 }
292 
294  OpAsmPrinter &printer) {
295  printer << ' ';
296  printer.printOptionalAttrDict(op->getAttrs());
297  printer.printOperands(op->getOperands());
298  printer << " : " << op->getResultTypes().front();
299 }
300 
301 static LogicalResult verifyShiftOp(Operation *op) {
302  if (op->getOperand(0).getType() != op->getResult(0).getType()) {
303  return op->emitError("expected the same type for the first operand and "
304  "result, but provided ")
305  << op->getOperand(0).getType() << " and "
306  << op->getResult(0).getType();
307  }
308  return success();
309 }
310 
311 //===----------------------------------------------------------------------===//
312 // spirv.mlir.addressof
313 //===----------------------------------------------------------------------===//
314 
315 void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state,
316  spirv::GlobalVariableOp var) {
317  build(builder, state, var.getType(), SymbolRefAttr::get(var));
318 }
319 
320 LogicalResult spirv::AddressOfOp::verify() {
321  auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
322  SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(),
323  getVariableAttr()));
324  if (!varOp) {
325  return emitOpError("expected spirv.GlobalVariable symbol");
326  }
327  if (getPointer().getType() != varOp.getType()) {
328  return emitOpError(
329  "result type mismatch with the referenced global variable's type");
330  }
331  return success();
332 }
333 
334 //===----------------------------------------------------------------------===//
335 // spirv.CompositeConstruct
336 //===----------------------------------------------------------------------===//
337 
338 LogicalResult spirv::CompositeConstructOp::verify() {
339  operand_range constituents = this->getConstituents();
340 
341  // There are 4 cases with varying verification rules:
342  // 1. Cooperative Matrices (1 constituent)
343  // 2. Structs (1 constituent for each member)
344  // 3. Arrays (1 constituent for each array element)
345  // 4. Vectors (1 constituent (sub-)element for each vector element)
346 
347  auto coopElementType =
350  [](auto coopType) { return coopType.getElementType(); })
351  .Default([](Type) { return nullptr; });
352 
353  // Case 1. -- matrices.
354  if (coopElementType) {
355  if (constituents.size() != 1)
356  return emitOpError("has incorrect number of operands: expected ")
357  << "1, but provided " << constituents.size();
358  if (coopElementType != constituents.front().getType())
359  return emitOpError("operand type mismatch: expected operand type ")
360  << coopElementType << ", but provided "
361  << constituents.front().getType();
362  return success();
363  }
364 
365  // Case 2./3./4. -- number of constituents matches the number of elements.
366  auto cType = llvm::cast<spirv::CompositeType>(getType());
367  if (constituents.size() == cType.getNumElements()) {
368  for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
369  if (constituents[index].getType() != cType.getElementType(index)) {
370  return emitOpError("operand type mismatch: expected operand type ")
371  << cType.getElementType(index) << ", but provided "
372  << constituents[index].getType();
373  }
374  }
375  return success();
376  }
377 
378  // Case 4. -- check that all constituents add up to the expected vector type.
379  auto resultType = llvm::dyn_cast<VectorType>(cType);
380  if (!resultType)
381  return emitOpError(
382  "expected to return a vector or cooperative matrix when the number of "
383  "constituents is less than what the result needs");
384 
385  SmallVector<unsigned> sizes;
386  for (Value component : constituents) {
387  if (!llvm::isa<VectorType>(component.getType()) &&
388  !component.getType().isIntOrFloat())
389  return emitOpError("operand type mismatch: expected operand to have "
390  "a scalar or vector type, but provided ")
391  << component.getType();
392 
393  Type elementType = component.getType();
394  if (auto vectorType = llvm::dyn_cast<VectorType>(component.getType())) {
395  sizes.push_back(vectorType.getNumElements());
396  elementType = vectorType.getElementType();
397  } else {
398  sizes.push_back(1);
399  }
400 
401  if (elementType != resultType.getElementType())
402  return emitOpError("operand element type mismatch: expected to be ")
403  << resultType.getElementType() << ", but provided " << elementType;
404  }
405  unsigned totalCount = std::accumulate(sizes.begin(), sizes.end(), 0);
406  if (totalCount != cType.getNumElements())
407  return emitOpError("has incorrect number of operands: expected ")
408  << cType.getNumElements() << ", but provided " << totalCount;
409  return success();
410 }
411 
412 //===----------------------------------------------------------------------===//
413 // spirv.CompositeExtractOp
414 //===----------------------------------------------------------------------===//
415 
416 void spirv::CompositeExtractOp::build(OpBuilder &builder, OperationState &state,
417  Value composite,
418  ArrayRef<int32_t> indices) {
419  auto indexAttr = builder.getI32ArrayAttr(indices);
420  auto elementType =
421  getElementType(composite.getType(), indexAttr, state.location);
422  if (!elementType) {
423  return;
424  }
425  build(builder, state, elementType, composite, indexAttr);
426 }
427 
429  OperationState &result) {
430  OpAsmParser::UnresolvedOperand compositeInfo;
431  Attribute indicesAttr;
432  StringRef indicesAttrName =
433  spirv::CompositeExtractOp::getIndicesAttrName(result.name);
434  Type compositeType;
435  SMLoc attrLocation;
436 
437  if (parser.parseOperand(compositeInfo) ||
438  parser.getCurrentLocation(&attrLocation) ||
439  parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
440  parser.parseColonType(compositeType) ||
441  parser.resolveOperand(compositeInfo, compositeType, result.operands)) {
442  return failure();
443  }
444 
445  Type resultType =
446  getElementType(compositeType, indicesAttr, parser, attrLocation);
447  if (!resultType) {
448  return failure();
449  }
450  result.addTypes(resultType);
451  return success();
452 }
453 
455  printer << ' ' << getComposite() << getIndices() << " : "
456  << getComposite().getType();
457 }
458 
459 LogicalResult spirv::CompositeExtractOp::verify() {
460  auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(getIndices());
461  auto resultType =
462  getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
463  if (!resultType)
464  return failure();
465 
466  if (resultType != getType()) {
467  return emitOpError("invalid result type: expected ")
468  << resultType << " but provided " << getType();
469  }
470 
471  return success();
472 }
473 
474 //===----------------------------------------------------------------------===//
475 // spirv.CompositeInsert
476 //===----------------------------------------------------------------------===//
477 
478 void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state,
479  Value object, Value composite,
480  ArrayRef<int32_t> indices) {
481  auto indexAttr = builder.getI32ArrayAttr(indices);
482  build(builder, state, composite.getType(), object, composite, indexAttr);
483 }
484 
485 ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
486  OperationState &result) {
488  Type objectType, compositeType;
489  Attribute indicesAttr;
490  StringRef indicesAttrName =
491  spirv::CompositeInsertOp::getIndicesAttrName(result.name);
492  auto loc = parser.getCurrentLocation();
493 
494  return failure(
495  parser.parseOperandList(operands, 2) ||
496  parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
497  parser.parseColonType(objectType) ||
498  parser.parseKeywordType("into", compositeType) ||
499  parser.resolveOperands(operands, {objectType, compositeType}, loc,
500  result.operands) ||
501  parser.addTypesToList(compositeType, result.types));
502 }
503 
504 LogicalResult spirv::CompositeInsertOp::verify() {
505  auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(getIndices());
506  auto objectType =
507  getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
508  if (!objectType)
509  return failure();
510 
511  if (objectType != getObject().getType()) {
512  return emitOpError("object operand type should be ")
513  << objectType << ", but found " << getObject().getType();
514  }
515 
516  if (getComposite().getType() != getType()) {
517  return emitOpError("result type should be the same as "
518  "the composite type, but found ")
519  << getComposite().getType() << " vs " << getType();
520  }
521 
522  return success();
523 }
524 
526  printer << " " << getObject() << ", " << getComposite() << getIndices()
527  << " : " << getObject().getType() << " into "
528  << getComposite().getType();
529 }
530 
531 //===----------------------------------------------------------------------===//
532 // spirv.Constant
533 //===----------------------------------------------------------------------===//
534 
535 ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
536  OperationState &result) {
537  Attribute value;
538  StringRef valueAttrName = spirv::ConstantOp::getValueAttrName(result.name);
539  if (parser.parseAttribute(value, valueAttrName, result.attributes))
540  return failure();
541 
542  Type type = NoneType::get(parser.getContext());
543  if (auto typedAttr = llvm::dyn_cast<TypedAttr>(value))
544  type = typedAttr.getType();
545  if (llvm::isa<NoneType, TensorType>(type)) {
546  if (parser.parseColonType(type))
547  return failure();
548  }
549 
550  return parser.addTypeToList(type, result.types);
551 }
552 
553 void spirv::ConstantOp::print(OpAsmPrinter &printer) {
554  printer << ' ' << getValue();
555  if (llvm::isa<spirv::ArrayType>(getType()))
556  printer << " : " << getType();
557 }
558 
559 static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
560  Type opType) {
561  if (isa<spirv::CooperativeMatrixType>(opType)) {
562  auto denseAttr = dyn_cast<DenseElementsAttr>(value);
563  if (!denseAttr || !denseAttr.isSplat())
564  return op.emitOpError("expected a splat dense attribute for cooperative "
565  "matrix constant, but found ")
566  << denseAttr;
567  }
568  if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
569  auto valueType = llvm::cast<TypedAttr>(value).getType();
570  if (valueType != opType)
571  return op.emitOpError("result type (")
572  << opType << ") does not match value type (" << valueType << ")";
573  return success();
574  }
575  if (llvm::isa<DenseIntOrFPElementsAttr, SparseElementsAttr>(value)) {
576  auto valueType = llvm::cast<TypedAttr>(value).getType();
577  if (valueType == opType)
578  return success();
579  auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
580  auto shapedType = llvm::dyn_cast<ShapedType>(valueType);
581  if (!arrayType)
582  return op.emitOpError("result or element type (")
583  << opType << ") does not match value type (" << valueType
584  << "), must be the same or spirv.array";
585 
586  int numElements = arrayType.getNumElements();
587  auto opElemType = arrayType.getElementType();
588  while (auto t = llvm::dyn_cast<spirv::ArrayType>(opElemType)) {
589  numElements *= t.getNumElements();
590  opElemType = t.getElementType();
591  }
592  if (!opElemType.isIntOrFloat())
593  return op.emitOpError("only support nested array result type");
594 
595  auto valueElemType = shapedType.getElementType();
596  if (valueElemType != opElemType) {
597  return op.emitOpError("result element type (")
598  << opElemType << ") does not match value element type ("
599  << valueElemType << ")";
600  }
601 
602  if (numElements != shapedType.getNumElements()) {
603  return op.emitOpError("result number of elements (")
604  << numElements << ") does not match value number of elements ("
605  << shapedType.getNumElements() << ")";
606  }
607  return success();
608  }
609  if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(value)) {
610  auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
611  if (!arrayType)
612  return op.emitOpError(
613  "must have spirv.array result type for array value");
614  Type elemType = arrayType.getElementType();
615  for (Attribute element : arrayAttr.getValue()) {
616  // Verify array elements recursively.
617  if (failed(verifyConstantType(op, element, elemType)))
618  return failure();
619  }
620  return success();
621  }
622  return op.emitOpError("cannot have attribute: ") << value;
623 }
624 
625 LogicalResult spirv::ConstantOp::verify() {
626  // ODS already generates checks to make sure the result type is valid. We just
627  // need to additionally check that the value's attribute type is consistent
628  // with the result type.
629  return verifyConstantType(*this, getValueAttr(), getType());
630 }
631 
632 bool spirv::ConstantOp::isBuildableWith(Type type) {
633  // Must be valid SPIR-V type first.
634  if (!llvm::isa<spirv::SPIRVType>(type))
635  return false;
636 
637  if (isa<SPIRVDialect>(type.getDialect())) {
638  // TODO: support constant struct
639  return llvm::isa<spirv::ArrayType>(type);
640  }
641 
642  return true;
643 }
644 
645 spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
646  OpBuilder &builder) {
647  if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
648  unsigned width = intType.getWidth();
649  if (width == 1)
650  return builder.create<spirv::ConstantOp>(loc, type,
651  builder.getBoolAttr(false));
652  return builder.create<spirv::ConstantOp>(
653  loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
654  }
655  if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
656  return builder.create<spirv::ConstantOp>(
657  loc, type, builder.getFloatAttr(floatType, 0.0));
658  }
659  if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
660  Type elemType = vectorType.getElementType();
661  if (llvm::isa<IntegerType>(elemType)) {
662  return builder.create<spirv::ConstantOp>(
663  loc, type,
664  DenseElementsAttr::get(vectorType,
665  IntegerAttr::get(elemType, 0).getValue()));
666  }
667  if (llvm::isa<FloatType>(elemType)) {
668  return builder.create<spirv::ConstantOp>(
669  loc, type,
670  DenseFPElementsAttr::get(vectorType,
671  FloatAttr::get(elemType, 0.0).getValue()));
672  }
673  }
674 
675  llvm_unreachable("unimplemented types for ConstantOp::getZero()");
676 }
677 
678 spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
679  OpBuilder &builder) {
680  if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
681  unsigned width = intType.getWidth();
682  if (width == 1)
683  return builder.create<spirv::ConstantOp>(loc, type,
684  builder.getBoolAttr(true));
685  return builder.create<spirv::ConstantOp>(
686  loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
687  }
688  if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
689  return builder.create<spirv::ConstantOp>(
690  loc, type, builder.getFloatAttr(floatType, 1.0));
691  }
692  if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
693  Type elemType = vectorType.getElementType();
694  if (llvm::isa<IntegerType>(elemType)) {
695  return builder.create<spirv::ConstantOp>(
696  loc, type,
697  DenseElementsAttr::get(vectorType,
698  IntegerAttr::get(elemType, 1).getValue()));
699  }
700  if (llvm::isa<FloatType>(elemType)) {
701  return builder.create<spirv::ConstantOp>(
702  loc, type,
703  DenseFPElementsAttr::get(vectorType,
704  FloatAttr::get(elemType, 1.0).getValue()));
705  }
706  }
707 
708  llvm_unreachable("unimplemented types for ConstantOp::getOne()");
709 }
710 
711 void mlir::spirv::ConstantOp::getAsmResultNames(
712  llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
713  Type type = getType();
714 
715  SmallString<32> specialNameBuffer;
716  llvm::raw_svector_ostream specialName(specialNameBuffer);
717  specialName << "cst";
718 
719  IntegerType intTy = llvm::dyn_cast<IntegerType>(type);
720 
721  if (IntegerAttr intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
722  if (intTy && intTy.getWidth() == 1) {
723  return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
724  }
725 
726  if (intTy.isSignless()) {
727  specialName << intCst.getInt();
728  } else if (intTy.isUnsigned()) {
729  specialName << intCst.getUInt();
730  } else {
731  specialName << intCst.getSInt();
732  }
733  }
734 
735  if (intTy || llvm::isa<FloatType>(type)) {
736  specialName << '_' << type;
737  }
738 
739  if (auto vecType = llvm::dyn_cast<VectorType>(type)) {
740  specialName << "_vec_";
741  specialName << vecType.getDimSize(0);
742 
743  Type elementType = vecType.getElementType();
744 
745  if (llvm::isa<IntegerType>(elementType) ||
746  llvm::isa<FloatType>(elementType)) {
747  specialName << "x" << elementType;
748  }
749  }
750 
751  setNameFn(getResult(), specialName.str());
752 }
753 
754 void mlir::spirv::AddressOfOp::getAsmResultNames(
755  llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
756  SmallString<32> specialNameBuffer;
757  llvm::raw_svector_ostream specialName(specialNameBuffer);
758  specialName << getVariable() << "_addr";
759  setNameFn(getResult(), specialName.str());
760 }
761 
762 //===----------------------------------------------------------------------===//
763 // spirv.ControlBarrierOp
764 //===----------------------------------------------------------------------===//
765 
766 LogicalResult spirv::ControlBarrierOp::verify() {
767  return verifyMemorySemantics(getOperation(), getMemorySemantics());
768 }
769 
770 //===----------------------------------------------------------------------===//
771 // spirv.EntryPoint
772 //===----------------------------------------------------------------------===//
773 
774 void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state,
775  spirv::ExecutionModel executionModel,
776  spirv::FuncOp function,
777  ArrayRef<Attribute> interfaceVars) {
778  build(builder, state,
779  spirv::ExecutionModelAttr::get(builder.getContext(), executionModel),
780  SymbolRefAttr::get(function), builder.getArrayAttr(interfaceVars));
781 }
782 
783 ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
784  OperationState &result) {
785  spirv::ExecutionModel execModel;
786  SmallVector<Attribute, 4> interfaceVars;
787 
789  if (parseEnumStrAttr<spirv::ExecutionModelAttr>(execModel, parser, result) ||
790  parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) {
791  return failure();
792  }
793 
794  if (!parser.parseOptionalComma()) {
795  // Parse the interface variables
796  if (parser.parseCommaSeparatedList([&]() -> ParseResult {
797  // The name of the interface variable attribute isn't important
798  FlatSymbolRefAttr var;
799  NamedAttrList attrs;
800  if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
801  return failure();
802  interfaceVars.push_back(var);
803  return success();
804  }))
805  return failure();
806  }
807  result.addAttribute(spirv::EntryPointOp::getInterfaceAttrName(result.name),
808  parser.getBuilder().getArrayAttr(interfaceVars));
809  return success();
810 }
811 
813  printer << " \"" << stringifyExecutionModel(getExecutionModel()) << "\" ";
814  printer.printSymbolName(getFn());
815  auto interfaceVars = getInterface().getValue();
816  if (!interfaceVars.empty())
817  printer << ", " << llvm::interleaved(interfaceVars);
818 }
819 
820 LogicalResult spirv::EntryPointOp::verify() {
821  // Checks for fn and interface symbol reference are done in spirv::ModuleOp
822  // verification.
823  return success();
824 }
825 
826 //===----------------------------------------------------------------------===//
827 // spirv.ExecutionMode
828 //===----------------------------------------------------------------------===//
829 
830 void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state,
831  spirv::FuncOp function,
832  spirv::ExecutionMode executionMode,
833  ArrayRef<int32_t> params) {
834  build(builder, state, SymbolRefAttr::get(function),
835  spirv::ExecutionModeAttr::get(builder.getContext(), executionMode),
836  builder.getI32ArrayAttr(params));
837 }
838 
839 ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
840  OperationState &result) {
841  spirv::ExecutionMode execMode;
842  Attribute fn;
843  if (parser.parseAttribute(fn, kFnNameAttrName, result.attributes) ||
844  parseEnumStrAttr<spirv::ExecutionModeAttr>(execMode, parser, result)) {
845  return failure();
846  }
847 
849  Type i32Type = parser.getBuilder().getIntegerType(32);
850  while (!parser.parseOptionalComma()) {
851  NamedAttrList attr;
852  Attribute value;
853  if (parser.parseAttribute(value, i32Type, "value", attr)) {
854  return failure();
855  }
856  values.push_back(llvm::cast<IntegerAttr>(value).getInt());
857  }
858  StringRef valuesAttrName =
859  spirv::ExecutionModeOp::getValuesAttrName(result.name);
860  result.addAttribute(valuesAttrName,
861  parser.getBuilder().getI32ArrayAttr(values));
862  return success();
863 }
864 
866  printer << " ";
867  printer.printSymbolName(getFn());
868  printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\"";
869  ArrayAttr values = this->getValues();
870  if (!values.empty())
871  printer << ", " << llvm::interleaved(values.getAsValueRange<IntegerAttr>());
872 }
873 
874 //===----------------------------------------------------------------------===//
875 // spirv.func
876 //===----------------------------------------------------------------------===//
877 
878 ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
880  SmallVector<DictionaryAttr> resultAttrs;
881  SmallVector<Type> resultTypes;
882  auto &builder = parser.getBuilder();
883 
884  // Parse the name as a symbol.
885  StringAttr nameAttr;
886  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
887  result.attributes))
888  return failure();
889 
890  // Parse the function signature.
891  bool isVariadic = false;
893  parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
894  resultAttrs))
895  return failure();
896 
897  SmallVector<Type> argTypes;
898  for (auto &arg : entryArgs)
899  argTypes.push_back(arg.type);
900  auto fnType = builder.getFunctionType(argTypes, resultTypes);
901  result.addAttribute(getFunctionTypeAttrName(result.name),
902  TypeAttr::get(fnType));
903 
904  // Parse the optional function control keyword.
905  spirv::FunctionControl fnControl;
906  if (parseEnumStrAttr<spirv::FunctionControlAttr>(fnControl, parser, result))
907  return failure();
908 
909  // If additional attributes are present, parse them.
910  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
911  return failure();
912 
913  // Add the attributes to the function arguments.
914  assert(resultAttrs.size() == resultTypes.size());
916  builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
917  getResAttrsAttrName(result.name));
918 
919  // Parse the optional function body.
920  auto *body = result.addRegion();
921  OptionalParseResult parseResult =
922  parser.parseOptionalRegion(*body, entryArgs);
923  return failure(parseResult.has_value() && failed(*parseResult));
924 }
925 
926 void spirv::FuncOp::print(OpAsmPrinter &printer) {
927  // Print function name, signature, and control.
928  printer << " ";
929  printer.printSymbolName(getSymName());
930  auto fnType = getFunctionType();
932  printer, *this, fnType.getInputs(),
933  /*isVariadic=*/false, fnType.getResults());
934  printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl())
935  << "\"";
937  printer, *this,
938  {spirv::attributeName<spirv::FunctionControl>(),
939  getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
940  getFunctionControlAttrName()});
941 
942  // Print the body if this is not an external function.
943  Region &body = this->getBody();
944  if (!body.empty()) {
945  printer << ' ';
946  printer.printRegion(body, /*printEntryBlockArgs=*/false,
947  /*printBlockTerminators=*/true);
948  }
949 }
950 
951 LogicalResult spirv::FuncOp::verifyType() {
952  FunctionType fnType = getFunctionType();
953  if (fnType.getNumResults() > 1)
954  return emitOpError("cannot have more than one result");
955 
956  auto hasDecorationAttr = [&](spirv::Decoration decoration,
957  unsigned argIndex) {
958  auto func = llvm::cast<FunctionOpInterface>(getOperation());
959  for (auto argAttr : cast<FunctionOpInterface>(func).getArgAttrs(argIndex)) {
960  if (argAttr.getName() != spirv::DecorationAttr::name)
961  continue;
962  if (auto decAttr = dyn_cast<spirv::DecorationAttr>(argAttr.getValue()))
963  return decAttr.getValue() == decoration;
964  }
965  return false;
966  };
967 
968  for (unsigned i = 0, e = this->getNumArguments(); i != e; ++i) {
969  Type param = fnType.getInputs()[i];
970  auto inputPtrType = dyn_cast<spirv::PointerType>(param);
971  if (!inputPtrType)
972  continue;
973 
974  auto pointeePtrType =
975  dyn_cast<spirv::PointerType>(inputPtrType.getPointeeType());
976  if (pointeePtrType) {
977  // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
978  // > If an OpFunctionParameter is a pointer (or contains a pointer)
979  // > and the type it points to is a pointer in the PhysicalStorageBuffer
980  // > storage class, the function parameter must be decorated with exactly
981  // > one of AliasedPointer or RestrictPointer.
982  if (pointeePtrType.getStorageClass() !=
983  spirv::StorageClass::PhysicalStorageBuffer)
984  continue;
985 
986  bool hasAliasedPtr =
987  hasDecorationAttr(spirv::Decoration::AliasedPointer, i);
988  bool hasRestrictPtr =
989  hasDecorationAttr(spirv::Decoration::RestrictPointer, i);
990  if (!hasAliasedPtr && !hasRestrictPtr)
991  return emitOpError()
992  << "with a pointer points to a physical buffer pointer must "
993  "be decorated either 'AliasedPointer' or 'RestrictPointer'";
994  continue;
995  }
996  // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
997  // > If an OpFunctionParameter is a pointer (or contains a pointer) in
998  // > the PhysicalStorageBuffer storage class, the function parameter must
999  // > be decorated with exactly one of Aliased or Restrict.
1000  if (auto pointeeArrayType =
1001  dyn_cast<spirv::ArrayType>(inputPtrType.getPointeeType())) {
1002  pointeePtrType =
1003  dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
1004  } else {
1005  pointeePtrType = inputPtrType;
1006  }
1007 
1008  if (!pointeePtrType || pointeePtrType.getStorageClass() !=
1009  spirv::StorageClass::PhysicalStorageBuffer)
1010  continue;
1011 
1012  bool hasAliased = hasDecorationAttr(spirv::Decoration::Aliased, i);
1013  bool hasRestrict = hasDecorationAttr(spirv::Decoration::Restrict, i);
1014  if (!hasAliased && !hasRestrict)
1015  return emitOpError() << "with physical buffer pointer must be decorated "
1016  "either 'Aliased' or 'Restrict'";
1017  }
1018 
1019  return success();
1020 }
1021 
1022 LogicalResult spirv::FuncOp::verifyBody() {
1023  FunctionType fnType = getFunctionType();
1024  if (!isExternal()) {
1025  Block &entryBlock = front();
1026 
1027  unsigned numArguments = this->getNumArguments();
1028  if (entryBlock.getNumArguments() != numArguments)
1029  return emitOpError("entry block must have ")
1030  << numArguments << " arguments to match function signature";
1031 
1032  for (auto [index, fnArgType, blockArgType] :
1033  llvm::enumerate(getArgumentTypes(), entryBlock.getArgumentTypes())) {
1034  if (blockArgType != fnArgType) {
1035  return emitOpError("type of entry block argument #")
1036  << index << '(' << blockArgType
1037  << ") must match the type of the corresponding argument in "
1038  << "function signature(" << fnArgType << ')';
1039  }
1040  }
1041  }
1042 
1043  auto walkResult = walk([fnType](Operation *op) -> WalkResult {
1044  if (auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
1045  if (fnType.getNumResults() != 0)
1046  return retOp.emitOpError("cannot be used in functions returning value");
1047  } else if (auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
1048  if (fnType.getNumResults() != 1)
1049  return retOp.emitOpError(
1050  "returns 1 value but enclosing function requires ")
1051  << fnType.getNumResults() << " results";
1052 
1053  auto retOperandType = retOp.getValue().getType();
1054  auto fnResultType = fnType.getResult(0);
1055  if (retOperandType != fnResultType)
1056  return retOp.emitOpError(" return value's type (")
1057  << retOperandType << ") mismatch with function's result type ("
1058  << fnResultType << ")";
1059  }
1060  return WalkResult::advance();
1061  });
1062 
1063  // TODO: verify other bits like linkage type.
1064 
1065  return failure(walkResult.wasInterrupted());
1066 }
1067 
1068 void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
1069  StringRef name, FunctionType type,
1070  spirv::FunctionControl control,
1071  ArrayRef<NamedAttribute> attrs) {
1072  state.addAttribute(SymbolTable::getSymbolAttrName(),
1073  builder.getStringAttr(name));
1074  state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
1075  state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
1076  builder.getAttr<spirv::FunctionControlAttr>(control));
1077  state.attributes.append(attrs.begin(), attrs.end());
1078  state.addRegion();
1079 }
1080 
1081 //===----------------------------------------------------------------------===//
1082 // spirv.GLFClampOp
1083 //===----------------------------------------------------------------------===//
1084 
1085 ParseResult spirv::GLFClampOp::parse(OpAsmParser &parser,
1086  OperationState &result) {
1087  return parseOneResultSameOperandTypeOp(parser, result);
1088 }
1090 
1091 //===----------------------------------------------------------------------===//
1092 // spirv.GLUClampOp
1093 //===----------------------------------------------------------------------===//
1094 
1095 ParseResult spirv::GLUClampOp::parse(OpAsmParser &parser,
1096  OperationState &result) {
1097  return parseOneResultSameOperandTypeOp(parser, result);
1098 }
1100 
1101 //===----------------------------------------------------------------------===//
1102 // spirv.GLSClampOp
1103 //===----------------------------------------------------------------------===//
1104 
1105 ParseResult spirv::GLSClampOp::parse(OpAsmParser &parser,
1106  OperationState &result) {
1107  return parseOneResultSameOperandTypeOp(parser, result);
1108 }
1110 
1111 //===----------------------------------------------------------------------===//
1112 // spirv.GLFmaOp
1113 //===----------------------------------------------------------------------===//
1114 
1115 ParseResult spirv::GLFmaOp::parse(OpAsmParser &parser, OperationState &result) {
1116  return parseOneResultSameOperandTypeOp(parser, result);
1117 }
1118 void spirv::GLFmaOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1119 
1120 //===----------------------------------------------------------------------===//
1121 // spirv.GlobalVariable
1122 //===----------------------------------------------------------------------===//
1123 
1124 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
1125  Type type, StringRef name,
1126  unsigned descriptorSet, unsigned binding) {
1127  build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
1128  state.addAttribute(
1129  spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
1130  builder.getI32IntegerAttr(descriptorSet));
1131  state.addAttribute(
1132  spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
1133  builder.getI32IntegerAttr(binding));
1134 }
1135 
1136 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
1137  Type type, StringRef name,
1138  spirv::BuiltIn builtin) {
1139  build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
1140  state.addAttribute(
1141  spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
1142  builder.getStringAttr(spirv::stringifyBuiltIn(builtin)));
1143 }
1144 
1145 ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
1146  OperationState &result) {
1147  // Parse variable name.
1148  StringAttr nameAttr;
1149  StringRef initializerAttrName =
1150  spirv::GlobalVariableOp::getInitializerAttrName(result.name);
1151  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1152  result.attributes)) {
1153  return failure();
1154  }
1155 
1156  // Parse optional initializer
1157  if (succeeded(parser.parseOptionalKeyword(initializerAttrName))) {
1158  FlatSymbolRefAttr initSymbol;
1159  if (parser.parseLParen() ||
1160  parser.parseAttribute(initSymbol, Type(), initializerAttrName,
1161  result.attributes) ||
1162  parser.parseRParen())
1163  return failure();
1164  }
1165 
1166  if (parseVariableDecorations(parser, result)) {
1167  return failure();
1168  }
1169 
1170  Type type;
1171  StringRef typeAttrName =
1172  spirv::GlobalVariableOp::getTypeAttrName(result.name);
1173  auto loc = parser.getCurrentLocation();
1174  if (parser.parseColonType(type)) {
1175  return failure();
1176  }
1177  if (!llvm::isa<spirv::PointerType>(type)) {
1178  return parser.emitError(loc, "expected spirv.ptr type");
1179  }
1180  result.addAttribute(typeAttrName, TypeAttr::get(type));
1181 
1182  return success();
1183 }
1184 
1186  SmallVector<StringRef, 4> elidedAttrs{
1187  spirv::attributeName<spirv::StorageClass>()};
1188 
1189  // Print variable name.
1190  printer << ' ';
1191  printer.printSymbolName(getSymName());
1192  elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
1193 
1194  StringRef initializerAttrName = this->getInitializerAttrName();
1195  // Print optional initializer
1196  if (auto initializer = this->getInitializer()) {
1197  printer << " " << initializerAttrName << '(';
1198  printer.printSymbolName(*initializer);
1199  printer << ')';
1200  elidedAttrs.push_back(initializerAttrName);
1201  }
1202 
1203  StringRef typeAttrName = this->getTypeAttrName();
1204  elidedAttrs.push_back(typeAttrName);
1205  spirv::printVariableDecorations(*this, printer, elidedAttrs);
1206  printer << " : " << getType();
1207 }
1208 
1209 LogicalResult spirv::GlobalVariableOp::verify() {
1210  if (!llvm::isa<spirv::PointerType>(getType()))
1211  return emitOpError("result must be of a !spv.ptr type");
1212 
1213  // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
1214  // object. It cannot be Generic. It must be the same as the Storage Class
1215  // operand of the Result Type."
1216  // Also, Function storage class is reserved by spirv.Variable.
1217  auto storageClass = this->storageClass();
1218  if (storageClass == spirv::StorageClass::Generic ||
1219  storageClass == spirv::StorageClass::Function) {
1220  return emitOpError("storage class cannot be '")
1221  << stringifyStorageClass(storageClass) << "'";
1222  }
1223 
1224  if (auto init = (*this)->getAttrOfType<FlatSymbolRefAttr>(
1225  this->getInitializerAttrName())) {
1227  (*this)->getParentOp(), init.getAttr());
1228  // TODO: Currently only variable initialization with specialization
1229  // constants and other variables is supported. They could be normal
1230  // constants in the module scope as well.
1231  if (!initOp || !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp,
1232  spirv::SpecConstantCompositeOp>(initOp)) {
1233  return emitOpError("initializer must be result of a "
1234  "spirv.SpecConstant or spirv.GlobalVariable or "
1235  "spirv.SpecConstantCompositeOp op");
1236  }
1237  }
1238 
1239  return success();
1240 }
1241 
1242 //===----------------------------------------------------------------------===//
1243 // spirv.INTEL.SubgroupBlockRead
1244 //===----------------------------------------------------------------------===//
1245 
1247  if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
1248  return failure();
1249 
1250  return success();
1251 }
1252 
1253 //===----------------------------------------------------------------------===//
1254 // spirv.INTEL.SubgroupBlockWrite
1255 //===----------------------------------------------------------------------===//
1256 
1258  OperationState &result) {
1259  // Parse the storage class specification
1260  spirv::StorageClass storageClass;
1262  auto loc = parser.getCurrentLocation();
1263  Type elementType;
1264  if (parseEnumStrAttr(storageClass, parser) ||
1265  parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
1266  parser.parseType(elementType)) {
1267  return failure();
1268  }
1269 
1270  auto ptrType = spirv::PointerType::get(elementType, storageClass);
1271  if (auto valVecTy = llvm::dyn_cast<VectorType>(elementType))
1272  ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
1273 
1274  if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
1275  result.operands)) {
1276  return failure();
1277  }
1278  return success();
1279 }
1280 
1282  printer << " " << getPtr() << ", " << getValue() << " : "
1283  << getValue().getType();
1284 }
1285 
1287  if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
1288  return failure();
1289 
1290  return success();
1291 }
1292 
1293 //===----------------------------------------------------------------------===//
1294 // spirv.IAddCarryOp
1295 //===----------------------------------------------------------------------===//
1296 
1297 LogicalResult spirv::IAddCarryOp::verify() {
1299 }
1300 
1301 ParseResult spirv::IAddCarryOp::parse(OpAsmParser &parser,
1302  OperationState &result) {
1304 }
1305 
1306 void spirv::IAddCarryOp::print(OpAsmPrinter &printer) {
1307  ::printArithmeticExtendedBinaryOp(*this, printer);
1308 }
1309 
1310 //===----------------------------------------------------------------------===//
1311 // spirv.ISubBorrowOp
1312 //===----------------------------------------------------------------------===//
1313 
1314 LogicalResult spirv::ISubBorrowOp::verify() {
1316 }
1317 
1318 ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser,
1319  OperationState &result) {
1321 }
1322 
1324  ::printArithmeticExtendedBinaryOp(*this, printer);
1325 }
1326 
1327 //===----------------------------------------------------------------------===//
1328 // spirv.SMulExtended
1329 //===----------------------------------------------------------------------===//
1330 
1331 LogicalResult spirv::SMulExtendedOp::verify() {
1333 }
1334 
1335 ParseResult spirv::SMulExtendedOp::parse(OpAsmParser &parser,
1336  OperationState &result) {
1338 }
1339 
1341  ::printArithmeticExtendedBinaryOp(*this, printer);
1342 }
1343 
1344 //===----------------------------------------------------------------------===//
1345 // spirv.UMulExtended
1346 //===----------------------------------------------------------------------===//
1347 
1348 LogicalResult spirv::UMulExtendedOp::verify() {
1350 }
1351 
1352 ParseResult spirv::UMulExtendedOp::parse(OpAsmParser &parser,
1353  OperationState &result) {
1355 }
1356 
1358  ::printArithmeticExtendedBinaryOp(*this, printer);
1359 }
1360 
1361 //===----------------------------------------------------------------------===//
1362 // spirv.MemoryBarrierOp
1363 //===----------------------------------------------------------------------===//
1364 
1365 LogicalResult spirv::MemoryBarrierOp::verify() {
1366  return verifyMemorySemantics(getOperation(), getMemorySemantics());
1367 }
1368 
1369 //===----------------------------------------------------------------------===//
1370 // spirv.module
1371 //===----------------------------------------------------------------------===//
1372 
1373 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
1374  std::optional<StringRef> name) {
1375  OpBuilder::InsertionGuard guard(builder);
1376  builder.createBlock(state.addRegion());
1377  if (name) {
1378  state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
1379  builder.getStringAttr(*name));
1380  }
1381 }
1382 
1383 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
1384  spirv::AddressingModel addressingModel,
1385  spirv::MemoryModel memoryModel,
1386  std::optional<VerCapExtAttr> vceTriple,
1387  std::optional<StringRef> name) {
1388  state.addAttribute(
1389  "addressing_model",
1390  builder.getAttr<spirv::AddressingModelAttr>(addressingModel));
1391  state.addAttribute("memory_model",
1392  builder.getAttr<spirv::MemoryModelAttr>(memoryModel));
1393  OpBuilder::InsertionGuard guard(builder);
1394  builder.createBlock(state.addRegion());
1395  if (vceTriple)
1396  state.addAttribute(getVCETripleAttrName(), *vceTriple);
1397  if (name)
1398  state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
1399  builder.getStringAttr(*name));
1400 }
1401 
1402 ParseResult spirv::ModuleOp::parse(OpAsmParser &parser,
1403  OperationState &result) {
1404  Region *body = result.addRegion();
1405 
1406  // If the name is present, parse it.
1407  StringAttr nameAttr;
1408  (void)parser.parseOptionalSymbolName(
1409  nameAttr, mlir::SymbolTable::getSymbolAttrName(), result.attributes);
1410 
1411  // Parse attributes
1412  spirv::AddressingModel addrModel;
1413  spirv::MemoryModel memoryModel;
1414  if (spirv::parseEnumKeywordAttr<spirv::AddressingModelAttr>(addrModel, parser,
1415  result) ||
1416  spirv::parseEnumKeywordAttr<spirv::MemoryModelAttr>(memoryModel, parser,
1417  result))
1418  return failure();
1419 
1420  if (succeeded(parser.parseOptionalKeyword("requires"))) {
1421  spirv::VerCapExtAttr vceTriple;
1422  if (parser.parseAttribute(vceTriple,
1423  spirv::ModuleOp::getVCETripleAttrName(),
1424  result.attributes))
1425  return failure();
1426  }
1427 
1428  if (parser.parseOptionalAttrDictWithKeyword(result.attributes) ||
1429  parser.parseRegion(*body, /*arguments=*/{}))
1430  return failure();
1431 
1432  // Make sure we have at least one block.
1433  if (body->empty())
1434  body->push_back(new Block());
1435 
1436  return success();
1437 }
1438 
1439 void spirv::ModuleOp::print(OpAsmPrinter &printer) {
1440  if (std::optional<StringRef> name = getName()) {
1441  printer << ' ';
1442  printer.printSymbolName(*name);
1443  }
1444 
1445  SmallVector<StringRef, 2> elidedAttrs;
1446 
1447  printer << " " << spirv::stringifyAddressingModel(getAddressingModel()) << " "
1448  << spirv::stringifyMemoryModel(getMemoryModel());
1449  auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
1450  auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
1451  elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
1453 
1454  if (std::optional<spirv::VerCapExtAttr> triple = getVceTriple()) {
1455  printer << " requires " << *triple;
1456  elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
1457  }
1458 
1459  printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs);
1460  printer << ' ';
1461  printer.printRegion(getRegion());
1462 }
1463 
1464 LogicalResult spirv::ModuleOp::verifyRegions() {
1465  Dialect *dialect = (*this)->getDialect();
1467  entryPoints;
1468  mlir::SymbolTable table(*this);
1469 
1470  for (auto &op : *getBody()) {
1471  if (op.getDialect() != dialect)
1472  return op.emitError("'spirv.module' can only contain spirv.* ops");
1473 
1474  // For EntryPoint op, check that the function and execution model is not
1475  // duplicated in EntryPointOps. Also verify that the interface specified
1476  // comes from globalVariables here to make this check cheaper.
1477  if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
1478  auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.getFn());
1479  if (!funcOp) {
1480  return entryPointOp.emitError("function '")
1481  << entryPointOp.getFn() << "' not found in 'spirv.module'";
1482  }
1483  if (auto interface = entryPointOp.getInterface()) {
1484  for (Attribute varRef : interface) {
1485  auto varSymRef = llvm::dyn_cast<FlatSymbolRefAttr>(varRef);
1486  if (!varSymRef) {
1487  return entryPointOp.emitError(
1488  "expected symbol reference for interface "
1489  "specification instead of '")
1490  << varRef;
1491  }
1492  auto variableOp =
1493  table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
1494  if (!variableOp) {
1495  return entryPointOp.emitError("expected spirv.GlobalVariable "
1496  "symbol reference instead of'")
1497  << varSymRef << "'";
1498  }
1499  }
1500  }
1501 
1502  auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
1503  funcOp, entryPointOp.getExecutionModel());
1504  if (!entryPoints.try_emplace(key, entryPointOp).second)
1505  return entryPointOp.emitError("duplicate of a previous EntryPointOp");
1506  } else if (auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
1507  // If the function is external and does not have 'Import'
1508  // linkage_attributes(LinkageAttributes), throw an error. 'Import'
1509  // LinkageAttributes is used to import external functions.
1510  auto linkageAttr = funcOp.getLinkageAttributes();
1511  auto hasImportLinkage =
1512  linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
1513  spirv::LinkageType::Import);
1514  if (funcOp.isExternal() && !hasImportLinkage)
1515  return op.emitError(
1516  "'spirv.module' cannot contain external functions "
1517  "without 'Import' linkage_attributes (LinkageAttributes)");
1518 
1519  // TODO: move this check to spirv.func.
1520  for (auto &block : funcOp)
1521  for (auto &op : block) {
1522  if (op.getDialect() != dialect)
1523  return op.emitError(
1524  "functions in 'spirv.module' can only contain spirv.* ops");
1525  }
1526  }
1527  }
1528 
1529  return success();
1530 }
1531 
1532 //===----------------------------------------------------------------------===//
1533 // spirv.mlir.referenceof
1534 //===----------------------------------------------------------------------===//
1535 
1536 LogicalResult spirv::ReferenceOfOp::verify() {
1537  auto *specConstSym = SymbolTable::lookupNearestSymbolFrom(
1538  (*this)->getParentOp(), getSpecConstAttr());
1539  Type constType;
1540 
1541  auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
1542  if (specConstOp)
1543  constType = specConstOp.getDefaultValue().getType();
1544 
1545  auto specConstCompositeOp =
1546  dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
1547  if (specConstCompositeOp)
1548  constType = specConstCompositeOp.getType();
1549 
1550  if (!specConstOp && !specConstCompositeOp)
1551  return emitOpError(
1552  "expected spirv.SpecConstant or spirv.SpecConstantComposite symbol");
1553 
1554  if (getReference().getType() != constType)
1555  return emitOpError("result type mismatch with the referenced "
1556  "specialization constant's type");
1557 
1558  return success();
1559 }
1560 
1561 //===----------------------------------------------------------------------===//
1562 // spirv.SpecConstant
1563 //===----------------------------------------------------------------------===//
1564 
1565 ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
1566  OperationState &result) {
1567  StringAttr nameAttr;
1568  Attribute valueAttr;
1569  StringRef defaultValueAttrName =
1570  spirv::SpecConstantOp::getDefaultValueAttrName(result.name);
1571 
1572  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1573  result.attributes))
1574  return failure();
1575 
1576  // Parse optional spec_id.
1577  if (succeeded(parser.parseOptionalKeyword(kSpecIdAttrName))) {
1578  IntegerAttr specIdAttr;
1579  if (parser.parseLParen() ||
1580  parser.parseAttribute(specIdAttr, kSpecIdAttrName, result.attributes) ||
1581  parser.parseRParen())
1582  return failure();
1583  }
1584 
1585  if (parser.parseEqual() ||
1586  parser.parseAttribute(valueAttr, defaultValueAttrName, result.attributes))
1587  return failure();
1588 
1589  return success();
1590 }
1591 
1593  printer << ' ';
1594  printer.printSymbolName(getSymName());
1595  if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
1596  printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')';
1597  printer << " = " << getDefaultValue();
1598 }
1599 
1600 LogicalResult spirv::SpecConstantOp::verify() {
1601  if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
1602  if (specID.getValue().isNegative())
1603  return emitOpError("SpecId cannot be negative");
1604 
1605  auto value = getDefaultValue();
1606  if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
1607  // Make sure bitwidth is allowed.
1608  if (!llvm::isa<spirv::SPIRVType>(value.getType()))
1609  return emitOpError("default value bitwidth disallowed");
1610  return success();
1611  }
1612  return emitOpError(
1613  "default value can only be a bool, integer, or float scalar");
1614 }
1615 
1616 //===----------------------------------------------------------------------===//
1617 // spirv.VectorShuffle
1618 //===----------------------------------------------------------------------===//
1619 
1620 LogicalResult spirv::VectorShuffleOp::verify() {
1621  VectorType resultType = llvm::cast<VectorType>(getType());
1622 
1623  size_t numResultElements = resultType.getNumElements();
1624  if (numResultElements != getComponents().size())
1625  return emitOpError("result type element count (")
1626  << numResultElements
1627  << ") mismatch with the number of component selectors ("
1628  << getComponents().size() << ")";
1629 
1630  size_t totalSrcElements =
1631  llvm::cast<VectorType>(getVector1().getType()).getNumElements() +
1632  llvm::cast<VectorType>(getVector2().getType()).getNumElements();
1633 
1634  for (const auto &selector : getComponents().getAsValueRange<IntegerAttr>()) {
1635  uint32_t index = selector.getZExtValue();
1636  if (index >= totalSrcElements &&
1637  index != std::numeric_limits<uint32_t>().max())
1638  return emitOpError("component selector ")
1639  << index << " out of range: expected to be in [0, "
1640  << totalSrcElements << ") or 0xffffffff";
1641  }
1642  return success();
1643 }
1644 
1645 //===----------------------------------------------------------------------===//
1646 // spirv.MatrixTimesScalar
1647 //===----------------------------------------------------------------------===//
1648 
1649 LogicalResult spirv::MatrixTimesScalarOp::verify() {
1650  Type elementType =
1651  llvm::TypeSwitch<Type, Type>(getMatrix().getType())
1653  [](auto matrixType) { return matrixType.getElementType(); })
1654  .Default([](Type) { return nullptr; });
1655 
1656  assert(elementType && "Unhandled type");
1657 
1658  // Check that the scalar type is the same as the matrix element type.
1659  if (getScalar().getType() != elementType)
1660  return emitOpError("input matrix components' type and scaling value must "
1661  "have the same type");
1662 
1663  return success();
1664 }
1665 
1666 //===----------------------------------------------------------------------===//
1667 // spirv.Transpose
1668 //===----------------------------------------------------------------------===//
1669 
1670 LogicalResult spirv::TransposeOp::verify() {
1671  auto inputMatrix = llvm::cast<spirv::MatrixType>(getMatrix().getType());
1672  auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().getType());
1673 
1674  // Verify that the input and output matrices have correct shapes.
1675  if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
1676  return emitError("input matrix rows count must be equal to "
1677  "output matrix columns count");
1678 
1679  if (inputMatrix.getNumColumns() != resultMatrix.getNumRows())
1680  return emitError("input matrix columns count must be equal to "
1681  "output matrix rows count");
1682 
1683  // Verify that the input and output matrices have the same component type
1684  if (inputMatrix.getElementType() != resultMatrix.getElementType())
1685  return emitError("input and output matrices must have the same "
1686  "component type");
1687 
1688  return success();
1689 }
1690 
1691 //===----------------------------------------------------------------------===//
1692 // spirv.MatrixTimesVector
1693 //===----------------------------------------------------------------------===//
1694 
1695 LogicalResult spirv::MatrixTimesVectorOp::verify() {
1696  auto matrixType = llvm::cast<spirv::MatrixType>(getMatrix().getType());
1697  auto vectorType = llvm::cast<VectorType>(getVector().getType());
1698  auto resultType = llvm::cast<VectorType>(getType());
1699 
1700  if (matrixType.getNumColumns() != vectorType.getNumElements())
1701  return emitOpError("matrix columns (")
1702  << matrixType.getNumColumns() << ") must match vector operand size ("
1703  << vectorType.getNumElements() << ")";
1704 
1705  if (resultType.getNumElements() != matrixType.getNumRows())
1706  return emitOpError("result size (")
1707  << resultType.getNumElements() << ") must match the matrix rows ("
1708  << matrixType.getNumRows() << ")";
1709 
1710  if (matrixType.getElementType() != resultType.getElementType())
1711  return emitOpError("matrix and result element types must match");
1712 
1713  return success();
1714 }
1715 
1716 //===----------------------------------------------------------------------===//
1717 // spirv.VectorTimesMatrix
1718 //===----------------------------------------------------------------------===//
1719 
1720 LogicalResult spirv::VectorTimesMatrixOp::verify() {
1721  auto vectorType = llvm::cast<VectorType>(getVector().getType());
1722  auto matrixType = llvm::cast<spirv::MatrixType>(getMatrix().getType());
1723  auto resultType = llvm::cast<VectorType>(getType());
1724 
1725  if (matrixType.getNumRows() != vectorType.getNumElements())
1726  return emitOpError("number of components in vector must equal the number "
1727  "of components in each column in matrix");
1728 
1729  if (resultType.getNumElements() != matrixType.getNumColumns())
1730  return emitOpError("number of columns in matrix must equal the number of "
1731  "components in result");
1732 
1733  if (matrixType.getElementType() != resultType.getElementType())
1734  return emitOpError("matrix must be a matrix with the same component type "
1735  "as the component type in result");
1736 
1737  return success();
1738 }
1739 
1740 //===----------------------------------------------------------------------===//
1741 // spirv.MatrixTimesMatrix
1742 //===----------------------------------------------------------------------===//
1743 
1744 LogicalResult spirv::MatrixTimesMatrixOp::verify() {
1745  auto leftMatrix = llvm::cast<spirv::MatrixType>(getLeftmatrix().getType());
1746  auto rightMatrix = llvm::cast<spirv::MatrixType>(getRightmatrix().getType());
1747  auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().getType());
1748 
1749  // left matrix columns' count and right matrix rows' count must be equal
1750  if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
1751  return emitError("left matrix columns' count must be equal to "
1752  "the right matrix rows' count");
1753 
1754  // right and result matrices columns' count must be the same
1755  if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns())
1756  return emitError(
1757  "right and result matrices must have equal columns' count");
1758 
1759  // right and result matrices component type must be the same
1760  if (rightMatrix.getElementType() != resultMatrix.getElementType())
1761  return emitError("right and result matrices' component type must"
1762  " be the same");
1763 
1764  // left and result matrices component type must be the same
1765  if (leftMatrix.getElementType() != resultMatrix.getElementType())
1766  return emitError("left and result matrices' component type"
1767  " must be the same");
1768 
1769  // left and result matrices rows count must be the same
1770  if (leftMatrix.getNumRows() != resultMatrix.getNumRows())
1771  return emitError("left and result matrices must have equal rows' count");
1772 
1773  return success();
1774 }
1775 
1776 //===----------------------------------------------------------------------===//
1777 // spirv.SpecConstantComposite
1778 //===----------------------------------------------------------------------===//
1779 
1781  OperationState &result) {
1782 
1783  StringAttr compositeName;
1784  if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
1785  result.attributes))
1786  return failure();
1787 
1788  if (parser.parseLParen())
1789  return failure();
1790 
1791  SmallVector<Attribute, 4> constituents;
1792 
1793  do {
1794  // The name of the constituent attribute isn't important
1795  const char *attrName = "spec_const";
1796  FlatSymbolRefAttr specConstRef;
1797  NamedAttrList attrs;
1798 
1799  if (parser.parseAttribute(specConstRef, Type(), attrName, attrs))
1800  return failure();
1801 
1802  constituents.push_back(specConstRef);
1803  } while (!parser.parseOptionalComma());
1804 
1805  if (parser.parseRParen())
1806  return failure();
1807 
1808  StringAttr compositeSpecConstituentsName =
1809  spirv::SpecConstantCompositeOp::getConstituentsAttrName(result.name);
1810  result.addAttribute(compositeSpecConstituentsName,
1811  parser.getBuilder().getArrayAttr(constituents));
1812 
1813  Type type;
1814  if (parser.parseColonType(type))
1815  return failure();
1816 
1817  StringAttr typeAttrName =
1818  spirv::SpecConstantCompositeOp::getTypeAttrName(result.name);
1819  result.addAttribute(typeAttrName, TypeAttr::get(type));
1820 
1821  return success();
1822 }
1823 
1825  printer << " ";
1826  printer.printSymbolName(getSymName());
1827  printer << " (" << llvm::interleaved(this->getConstituents().getValue())
1828  << ") : " << getType();
1829 }
1830 
1832  auto cType = llvm::dyn_cast<spirv::CompositeType>(getType());
1833  auto constituents = this->getConstituents().getValue();
1834 
1835  if (!cType)
1836  return emitError("result type must be a composite type, but provided ")
1837  << getType();
1838 
1839  if (llvm::isa<spirv::CooperativeMatrixType>(cType))
1840  return emitError("unsupported composite type ") << cType;
1841  if (constituents.size() != cType.getNumElements())
1842  return emitError("has incorrect number of operands: expected ")
1843  << cType.getNumElements() << ", but provided "
1844  << constituents.size();
1845 
1846  for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
1847  auto constituent = llvm::cast<FlatSymbolRefAttr>(constituents[index]);
1848 
1849  auto constituentSpecConstOp =
1850  dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom(
1851  (*this)->getParentOp(), constituent.getAttr()));
1852 
1853  if (constituentSpecConstOp.getDefaultValue().getType() !=
1854  cType.getElementType(index))
1855  return emitError("has incorrect types of operands: expected ")
1856  << cType.getElementType(index) << ", but provided "
1857  << constituentSpecConstOp.getDefaultValue().getType();
1858  }
1859 
1860  return success();
1861 }
1862 
1863 //===----------------------------------------------------------------------===//
1864 // spirv.SpecConstantOperation
1865 //===----------------------------------------------------------------------===//
1866 
1868  OperationState &result) {
1869  Region *body = result.addRegion();
1870 
1871  if (parser.parseKeyword("wraps"))
1872  return failure();
1873 
1874  body->push_back(new Block);
1875  Block &block = body->back();
1876  Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
1877 
1878  if (!wrappedOp)
1879  return failure();
1880 
1881  OpBuilder builder(parser.getContext());
1882  builder.setInsertionPointToEnd(&block);
1883  builder.create<spirv::YieldOp>(wrappedOp->getLoc(), wrappedOp->getResult(0));
1884  result.location = wrappedOp->getLoc();
1885 
1886  result.addTypes(wrappedOp->getResult(0).getType());
1887 
1888  if (parser.parseOptionalAttrDict(result.attributes))
1889  return failure();
1890 
1891  return success();
1892 }
1893 
1895  printer << " wraps ";
1896  printer.printGenericOp(&getBody().front().front());
1897 }
1898 
1899 LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
1900  Block &block = getRegion().getBlocks().front();
1901 
1902  if (block.getOperations().size() != 2)
1903  return emitOpError("expected exactly 2 nested ops");
1904 
1905  Operation &enclosedOp = block.getOperations().front();
1906 
1908  return emitOpError("invalid enclosed op");
1909 
1910  for (auto operand : enclosedOp.getOperands())
1911  if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
1912  spirv::SpecConstantOperationOp>(operand.getDefiningOp()))
1913  return emitOpError(
1914  "invalid operand, must be defined by a constant operation");
1915 
1916  return success();
1917 }
1918 
1919 //===----------------------------------------------------------------------===//
1920 // spirv.GL.FrexpStruct
1921 //===----------------------------------------------------------------------===//
1922 
1923 LogicalResult spirv::GLFrexpStructOp::verify() {
1924  spirv::StructType structTy =
1925  llvm::dyn_cast<spirv::StructType>(getResult().getType());
1926 
1927  if (structTy.getNumElements() != 2)
1928  return emitError("result type must be a struct type with two memebers");
1929 
1930  Type significandTy = structTy.getElementType(0);
1931  Type exponentTy = structTy.getElementType(1);
1932  VectorType exponentVecTy = llvm::dyn_cast<VectorType>(exponentTy);
1933  IntegerType exponentIntTy = llvm::dyn_cast<IntegerType>(exponentTy);
1934 
1935  Type operandTy = getOperand().getType();
1936  VectorType operandVecTy = llvm::dyn_cast<VectorType>(operandTy);
1937  FloatType operandFTy = llvm::dyn_cast<FloatType>(operandTy);
1938 
1939  if (significandTy != operandTy)
1940  return emitError("member zero of the resulting struct type must be the "
1941  "same type as the operand");
1942 
1943  if (exponentVecTy) {
1944  IntegerType componentIntTy =
1945  llvm::dyn_cast<IntegerType>(exponentVecTy.getElementType());
1946  if (!componentIntTy || componentIntTy.getWidth() != 32)
1947  return emitError("member one of the resulting struct type must"
1948  "be a scalar or vector of 32 bit integer type");
1949  } else if (!exponentIntTy || exponentIntTy.getWidth() != 32) {
1950  return emitError("member one of the resulting struct type "
1951  "must be a scalar or vector of 32 bit integer type");
1952  }
1953 
1954  // Check that the two member types have the same number of components
1955  if (operandVecTy && exponentVecTy &&
1956  (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
1957  return success();
1958 
1959  if (operandFTy && exponentIntTy)
1960  return success();
1961 
1962  return emitError("member one of the resulting struct type must have the same "
1963  "number of components as the operand type");
1964 }
1965 
1966 //===----------------------------------------------------------------------===//
1967 // spirv.GL.Ldexp
1968 //===----------------------------------------------------------------------===//
1969 
1970 LogicalResult spirv::GLLdexpOp::verify() {
1971  Type significandType = getX().getType();
1972  Type exponentType = getExp().getType();
1973 
1974  if (llvm::isa<FloatType>(significandType) !=
1975  llvm::isa<IntegerType>(exponentType))
1976  return emitOpError("operands must both be scalars or vectors");
1977 
1978  auto getNumElements = [](Type type) -> unsigned {
1979  if (auto vectorType = llvm::dyn_cast<VectorType>(type))
1980  return vectorType.getNumElements();
1981  return 1;
1982  };
1983 
1984  if (getNumElements(significandType) != getNumElements(exponentType))
1985  return emitOpError("operands must have the same number of elements");
1986 
1987  return success();
1988 }
1989 
1990 //===----------------------------------------------------------------------===//
1991 // spirv.ShiftLeftLogicalOp
1992 //===----------------------------------------------------------------------===//
1993 
1994 LogicalResult spirv::ShiftLeftLogicalOp::verify() {
1995  return verifyShiftOp(*this);
1996 }
1997 
1998 //===----------------------------------------------------------------------===//
1999 // spirv.ShiftRightArithmeticOp
2000 //===----------------------------------------------------------------------===//
2001 
2002 LogicalResult spirv::ShiftRightArithmeticOp::verify() {
2003  return verifyShiftOp(*this);
2004 }
2005 
2006 //===----------------------------------------------------------------------===//
2007 // spirv.ShiftRightLogicalOp
2008 //===----------------------------------------------------------------------===//
2009 
2010 LogicalResult spirv::ShiftRightLogicalOp::verify() {
2011  return verifyShiftOp(*this);
2012 }
2013 
2014 //===----------------------------------------------------------------------===//
2015 // spirv.VectorTimesScalarOp
2016 //===----------------------------------------------------------------------===//
2017 
2018 LogicalResult spirv::VectorTimesScalarOp::verify() {
2019  if (getVector().getType() != getType())
2020  return emitOpError("vector operand and result type mismatch");
2021  auto scalarType = llvm::cast<VectorType>(getType()).getElementType();
2022  if (getScalar().getType() != scalarType)
2023  return emitOpError("scalar operand and result element type match");
2024  return success();
2025 }
static std::string bindingName()
Returns the string name of the Binding decoration.
static std::string descriptorSetName()
Returns the string name of the DescriptorSet decoration.
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser, OperationState &result)
Definition: SPIRVOps.cpp:269
static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, Type opType)
Definition: SPIRVOps.cpp:559
static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, OperationState &result)
Definition: SPIRVOps.cpp:122
static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op)
Definition: SPIRVOps.cpp:255
static LogicalResult verifyShiftOp(Operation *op)
Definition: SPIRVOps.cpp:301
static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op, Value ptr, Value val)
Definition: SPIRVOps.cpp:171
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:188
static void printOneResultOp(Operation *op, OpAsmPrinter &p)
Definition: SPIRVOps.cpp:151
static void printArithmeticExtendedBinaryOp(Operation *op, OpAsmPrinter &printer)
Definition: SPIRVOps.cpp:293
const float * table
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseOptionalSymbolName(StringAttr &result)=0
Parse an optional @-identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
ParseResult addTypesToList(ArrayRef< Type > types, SmallVectorImpl< Type > &result)
Add the specified types to the end of the specified type list and return success.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:151
unsigned getNumArguments()
Definition: Block.h:128
OpListType & getOperations()
Definition: Block.h:137
Operation & front()
Definition: Block.h:153
iterator begin()
Definition: Block.h:143
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:198
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:226
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:274
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:252
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:78
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:69
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:98
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:260
MLIRContext * getContext() const
Definition: Builders.h:55
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:264
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition: Builders.h:96
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual Operation * parseGenericOperation(Block *insertBlock, Block::iterator insertPt)=0
Parse an operation in its generic form.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printGenericOp(Operation *op, bool printOpName=true)=0
Print the entire operation with the default generic assembly form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:434
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:428
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:455
A trait to mark ops that can be enclosed/wrapped in a SpecConstantOperation op.
Definition: SPIRVOpTraits.h:33
type_range getType() const
Definition: ValueRange.cpp:32
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:550
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
void push_back(Block *block)
Definition: Region.h:61
bool empty()
Definition: Region.h:60
Block & back()
Definition: Region.h:64
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:107
Type front()
Return first type in the range.
Definition: TypeRange.h:152
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:33
static WalkResult advance()
Definition: Visitors.h:51
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:427
SPIR-V struct type.
Definition: SPIRVTypes.h:293
unsigned getNumElements() const
Type getElementType(unsigned) const
An attribute that specifies the SPIR-V (version, capabilities, extensions) triple.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:136
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the argument at 'index'.
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:22
constexpr char kFnNameAttrName[]
constexpr char kSpecIdAttrName[]
LogicalResult verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics)
Definition: SPIRVOps.cpp:71
ParseResult parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next string attribute in parser as an enumerant of the given EnumClass.
void printVariableDecorations(Operation *op, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs)
Definition: SPIRVOps.cpp:95
AddressingModel getAddressingModel(TargetEnvAttr targetAttr, bool use64bitAddress)
Returns addressing model selected based on target environment.
FailureOr< ExecutionModel > getExecutionModel(TargetEnvAttr targetAttr)
Returns execution model selected based on target environment.
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
LogicalResult extractValueFromConstOp(Operation *op, int32_t &value)
Definition: SPIRVOps.cpp:51
ParseResult parseVariableDecorations(OpAsmParser &parser, OperationState &state)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.