MLIR  21.0.0git
SPIRVOps.cpp
Go to the documentation of this file.
1 //===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the operations in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
15 #include "SPIRVOpUtils.h"
16 #include "SPIRVParsingUtils.h"
17 
24 #include "mlir/IR/Builders.h"
25 #include "mlir/IR/BuiltinTypes.h"
26 #include "mlir/IR/Matchers.h"
27 #include "mlir/IR/OpDefinition.h"
29 #include "mlir/IR/Operation.h"
30 #include "mlir/IR/TypeUtilities.h"
32 #include "llvm/ADT/APFloat.h"
33 #include "llvm/ADT/APInt.h"
34 #include "llvm/ADT/ArrayRef.h"
35 #include "llvm/ADT/STLExtras.h"
36 #include "llvm/ADT/StringExtras.h"
37 #include "llvm/ADT/TypeSwitch.h"
38 #include <cassert>
39 #include <numeric>
40 #include <optional>
41 #include <type_traits>
42 
43 using namespace mlir;
44 using namespace mlir::spirv::AttrNames;
45 
46 //===----------------------------------------------------------------------===//
47 // Common utility functions
48 //===----------------------------------------------------------------------===//
49 
50 LogicalResult spirv::extractValueFromConstOp(Operation *op, int32_t &value) {
51  auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
52  if (!constOp) {
53  return failure();
54  }
55  auto valueAttr = constOp.getValue();
56  auto integerValueAttr = llvm::dyn_cast<IntegerAttr>(valueAttr);
57  if (!integerValueAttr) {
58  return failure();
59  }
60 
61  if (integerValueAttr.getType().isSignlessInteger())
62  value = integerValueAttr.getInt();
63  else
64  value = integerValueAttr.getSInt();
65 
66  return success();
67 }
68 
// Verifies a SPIR-V memory-semantics mask: at most one of the ordering bits
// Acquire, Release, AcquireRelease and SequentiallyConsistent may be set.
// On violation an error is emitted on `op` and failure is returned.
// NOTE(review): the declarator line of this definition (original line 70,
// presumably `spirv::verifyMemorySemantics(Operation *op,`) is missing from
// this listing -- confirm against the source.
69 LogicalResult
71  spirv::MemorySemantics memorySemantics) {
72  // According to the SPIR-V specification:
73  // "Despite being a mask and allowing multiple bits to be combined, it is
74  // invalid for more than one of these four bits to be set: Acquire, Release,
75  // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and
76  // Release semantics is done by setting the AcquireRelease bit, not by setting
77  // two bits."
78  auto atMostOneInSet = spirv::MemorySemantics::Acquire |
79  spirv::MemorySemantics::Release |
80  spirv::MemorySemantics::AcquireRelease |
81  spirv::MemorySemantics::SequentiallyConsistent;
82 
 // Keep only the four ordering bits and count how many of them are set.
83  auto bitCount =
84  llvm::popcount(static_cast<uint32_t>(memorySemantics & atMostOneInSet));
85  if (bitCount > 1) {
 // NOTE(review): the adjacent string literals below concatenate to
 // "`Release`,`AcquireRelease`" with no space after the comma.
86  return op->emitError(
87  "expected at most one of these four memory constraints "
88  "to be set: `Acquire`, `Release`,"
89  "`AcquireRelease` or `SequentiallyConsistent`");
90  }
91  return success();
92 }
93 
95  SmallVectorImpl<StringRef> &elidedAttrs) {
 // Prints variable decorations in custom form, then the remaining attributes:
 //  - ` bind(<set>, <binding>)` when both the descriptor-set and binding
 //    IntegerAttrs are present;
 //  - ` built_in("<name>")`-style output when the BuiltIn StringAttr is present.
 // Attribute names are the snake_case spellings of the decoration names.
 // Attributes printed here are appended to `elidedAttrs`, so the final
 // attr-dict skips them.
 // NOTE(review): the signature line (original line 94) is missing from this
 // listing. The names below come from convertToSnakeFromCamelCase, which
 // returns std::string locals, so the StringRefs pushed into the caller's
 // `elidedAttrs` dangle once this function returns. The caller must not read
 // `elidedAttrs` after this call -- confirm all call sites.
96  // Print optional descriptor binding
97  auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
98  stringifyDecoration(spirv::Decoration::DescriptorSet));
99  auto bindingName = llvm::convertToSnakeFromCamelCase(
100  stringifyDecoration(spirv::Decoration::Binding));
101  auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
102  auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
 // Only use the compact `bind(...)` form when both parts exist; otherwise
 // whichever one is present falls through to the attr-dict.
103  if (descriptorSet && binding) {
104  elidedAttrs.push_back(descriptorSetName);
105  elidedAttrs.push_back(bindingName);
106  printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
107  << ")";
108  }
109 
110  // Print BuiltIn attribute if present
111  auto builtInName = llvm::convertToSnakeFromCamelCase(
112  stringifyDecoration(spirv::Decoration::BuiltIn));
113  if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
114  printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
115  elidedAttrs.push_back(builtInName);
116  }
117 
 // Print every attribute not already covered above (or elided by the caller).
118  printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
119 }
120 
// Parses a single-result op in one of two forms:
//  - compact: `%a, %b {attrs} : T`, where every operand and the result have
//    type T;
//  - generic fallback: `(%a, %b) {attrs} : (T1, T2) -> R`, a function type
//    giving per-operand and result types.
// NOTE(review): original line 123 is missing from this listing; it presumably
// declares `ops` (a list of OpAsmParser::UnresolvedOperand) -- confirm.
121 static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
122  OperationState &result) {
124  Type type;
125  // If the operand list is in-between parentheses, then we have a generic form.
126  // (see the fallback in `printOneResultOp`).
127  SMLoc loc = parser.getCurrentLocation();
128  if (!parser.parseOptionalLParen()) {
129  if (parser.parseOperandList(ops) || parser.parseRParen() ||
130  parser.parseOptionalAttrDict(result.attributes) ||
131  parser.parseColon() || parser.parseType(type))
132  return failure();
 // The generic form must carry a function type so that operand and result
 // types can differ.
133  auto fnType = llvm::dyn_cast<FunctionType>(type);
134  if (!fnType) {
135  parser.emitError(loc, "expected function type");
136  return failure();
137  }
138  if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands))
139  return failure();
140  result.addTypes(fnType.getResults());
141  return success();
142  }
 // Compact form: a single type resolves all operands and the result.
143  return failure(parser.parseOperandList(ops) ||
144  parser.parseOptionalAttrDict(result.attributes) ||
145  parser.parseColonType(type) ||
146  parser.resolveOperands(ops, type, result.operands) ||
147  parser.addTypeToList(type, result.types));
148 }
149 
 // Prints a single-result op in the compact `%a, %b : T` form when all operand
 // types equal the result type, and otherwise in the generic form (which
 // parseOneResultSameOperandTypeOp accepts as its parenthesized fallback).
 // NOTE(review): this listing is missing the signature line (original line
 // 150, presumably `void printOneResultOp(Operation *op, OpAsmPrinter &p) {`)
 // and line 164, which presumably prints the attribute dictionary -- confirm.
151  assert(op->getNumResults() == 1 && "op should have one result");
152 
153  // If not all the operand and result types are the same, just use the
154  // generic assembly form to avoid omitting information in printing.
155  auto resultType = op->getResult(0).getType();
156  if (llvm::any_of(op->getOperandTypes(),
157  [&](Type type) { return type != resultType; })) {
158  p.printGenericOp(op, /*printOpName=*/false);
159  return;
160  }
161 
162  p << ' ';
163  p.printOperands(op->getOperands());
165  // Now we can output only one type for all operands and the result.
166  p << " : " << resultType;
167 }
168 
169 template <typename BlockReadWriteOpTy>
170 static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
171  Value ptr, Value val) {
172  auto valType = val.getType();
173  if (auto valVecTy = llvm::dyn_cast<VectorType>(valType))
174  valType = valVecTy.getElementType();
175 
176  if (valType !=
177  llvm::cast<spirv::PointerType>(ptr.getType()).getPointeeType()) {
178  return op.emitOpError("mismatch in result type and pointer type");
179  }
180  return success();
181 }
182 
183 /// Walks the given type hierarchy with the given indices, potentially down
184 /// to component granularity, to select an element type. Returns a null type
185 /// and reports errors through `emitErrorFn` on failure.
/// Errors: empty `indices`; indexing into a non-composite type; or an index
/// that is negative or out of bounds for a composite whose element count is
/// known at compile time.
/// NOTE(review): when the element count is not compile-time known, `index` is
/// passed to getElementType() without a bounds or sign check. Also, the
/// empty-indices message names spirv.CompositeExtract even when called for
/// other ops. The declarator line (original line 187, presumably
/// `getElementType(Type type, ArrayRef<int32_t> indices,`) is missing here.
186 static Type
188  function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
189  if (indices.empty()) {
190  emitErrorFn("expected at least one index for spirv.CompositeExtract");
191  return nullptr;
192  }
193 
 // Descend one level of the composite hierarchy per index.
194  for (auto index : indices) {
195  if (auto cType = llvm::dyn_cast<spirv::CompositeType>(type)) {
196  if (cType.hasCompileTimeKnownNumElements() &&
197  (index < 0 ||
198  static_cast<uint64_t>(index) >= cType.getNumElements())) {
199  emitErrorFn("index ") << index << " out of bounds for " << type;
200  return nullptr;
201  }
202  type = cType.getElementType(index);
203  } else {
204  emitErrorFn("cannot extract from non-composite type ")
205  << type << " with index " << index;
206  return nullptr;
207  }
208  }
209  return type;
210 }
211 
// Attribute-based overload: expects `indices` to be an ArrayAttr of
// IntegerAttrs, converts them to int32_t, and forwards to the
// ArrayRef<int32_t> overload. Returns a null type after reporting through
// `emitErrorFn` if the attribute is not an array, is empty, or holds a
// non-integer element.
// NOTE(review): getInt() presumably asserts on non-signless integer types, and
// the value is narrowed to int32_t when pushed. The declarator line (original
// line 213) is missing from this listing -- confirm.
212 static Type
214  function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
215  auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(indices);
216  if (!indicesArrayAttr) {
217  emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
218  return nullptr;
219  }
220  if (indicesArrayAttr.empty()) {
221  emitErrorFn("expected at least one index for spirv.CompositeExtract");
222  return nullptr;
223  }
224 
225  SmallVector<int32_t, 2> indexVals;
226  for (auto indexAttr : indicesArrayAttr) {
227  auto indexIntAttr = llvm::dyn_cast<IntegerAttr>(indexAttr);
228  if (!indexIntAttr) {
229  emitErrorFn("expected an 32-bit integer for index, but found '")
230  << indexAttr << "'";
231  return nullptr;
232  }
233  indexVals.push_back(indexIntAttr.getInt());
234  }
235  return getElementType(type, indexVals, emitErrorFn);
236 }
237 
238 static Type getElementType(Type type, Attribute indices, Location loc) {
239  auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
240  return ::mlir::emitError(loc, err);
241  };
242  return getElementType(type, indices, errorFn);
243 }
244 
245 static Type getElementType(Type type, Attribute indices, OpAsmParser &parser,
246  SMLoc loc) {
247  auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
248  return parser.emitError(loc, err);
249  };
250  return getElementType(type, indices, errorFn);
251 }
252 
253 template <typename ExtendedBinaryOp>
254 static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op) {
255  auto resultType = llvm::cast<spirv::StructType>(op.getType());
256  if (resultType.getNumElements() != 2)
257  return op.emitOpError("expected result struct type containing two members");
258 
259  if (!llvm::all_equal({op.getOperand1().getType(), op.getOperand2().getType(),
260  resultType.getElementType(0),
261  resultType.getElementType(1)}))
262  return op.emitOpError(
263  "expected all operand types and struct member types are the same");
264 
265  return success();
266 }
267 
// Parses `{attrs} %lhs, %rhs : !spirv.struct<(T, T)>`. The result type must be
// a two-member spirv.struct, and both operands are resolved with the struct's
// first member type.
// NOTE(review): original line 270 (presumably the `operands` list declaration)
// is missing from this listing. The operand count is not checked explicitly;
// presumably resolveOperands() reports a mismatch against the two types.
268 static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser,
269  OperationState &result) {
271  if (parser.parseOptionalAttrDict(result.attributes) ||
272  parser.parseOperandList(operands) || parser.parseColon())
273  return failure();
274 
275  Type resultType;
 // Remember the type's location so struct-shape errors point at it.
276  SMLoc loc = parser.getCurrentLocation();
277  if (parser.parseType(resultType))
278  return failure();
279 
280  auto structType = llvm::dyn_cast<spirv::StructType>(resultType);
281  if (!structType || structType.getNumElements() != 2)
282  return parser.emitError(loc, "expected spirv.struct type with two members");
283 
284  SmallVector<Type, 2> operandTypes(2, structType.getElementType(0));
285  if (parser.resolveOperands(operands, operandTypes, loc, result.operands))
286  return failure();
287 
288  result.addTypes(resultType);
289  return success();
290 }
291 
293  OpAsmPrinter &printer) {
 // Prints ` {attrs} %lhs, %rhs : <result type>`, matching the order read by
 // parseArithmeticExtendedBinaryOp (attr-dict first, then operands).
 // NOTE(review): the first signature line (original line 292) is missing from
 // this listing.
294  printer << ' ';
295  printer.printOptionalAttrDict(op->getAttrs());
296  printer.printOperands(op->getOperands());
297  printer << " : " << op->getResultTypes().front();
298 }
299 
300 static LogicalResult verifyShiftOp(Operation *op) {
301  if (op->getOperand(0).getType() != op->getResult(0).getType()) {
302  return op->emitError("expected the same type for the first operand and "
303  "result, but provided ")
304  << op->getOperand(0).getType() << " and "
305  << op->getResult(0).getType();
306  }
307  return success();
308 }
309 
310 //===----------------------------------------------------------------------===//
311 // spirv.mlir.addressof
312 //===----------------------------------------------------------------------===//
313 
314 void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state,
315  spirv::GlobalVariableOp var) {
316  build(builder, state, var.getType(), SymbolRefAttr::get(var));
317 }
318 
319 LogicalResult spirv::AddressOfOp::verify() {
320  auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
321  SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(),
322  getVariableAttr()));
323  if (!varOp) {
324  return emitOpError("expected spirv.GlobalVariable symbol");
325  }
326  if (getPointer().getType() != varOp.getType()) {
327  return emitOpError(
328  "result type mismatch with the referenced global variable's type");
329  }
330  return success();
331 }
332 
333 //===----------------------------------------------------------------------===//
334 // spirv.CompositeConstruct
335 //===----------------------------------------------------------------------===//
336 
// Verifies that the constituents of spirv.CompositeConstruct produce the
// result type (see the four cases below).
// NOTE(review): original lines 347-348, presumably the head of a TypeSwitch
// over the result type matching cooperative-matrix types, are missing from
// this listing -- confirm.
337 LogicalResult spirv::CompositeConstructOp::verify() {
338  operand_range constituents = this->getConstituents();
339 
340  // There are 4 cases with varying verification rules:
341  // 1. Cooperative Matrices (1 constituent)
342  // 2. Structs (1 constituent for each member)
343  // 3. Arrays (1 constituent for each array element)
344  // 4. Vectors (1 constituent (sub-)element for each vector element)
345 
 // Non-null only when the result is a cooperative-matrix type (Default below
 // yields nullptr for all other types).
346  auto coopElementType =
349  [](auto coopType) { return coopType.getElementType(); })
350  .Default([](Type) { return nullptr; });
351 
352  // Case 1. -- matrices.
353  if (coopElementType) {
354  if (constituents.size() != 1)
355  return emitOpError("has incorrect number of operands: expected ")
356  << "1, but provided " << constituents.size();
357  if (coopElementType != constituents.front().getType())
358  return emitOpError("operand type mismatch: expected operand type ")
359  << coopElementType << ", but provided "
360  << constituents.front().getType();
361  return success();
362  }
363 
364  // Case 2./3./4. -- number of constituents matches the number of elements.
365  auto cType = llvm::cast<spirv::CompositeType>(getType());
366  if (constituents.size() == cType.getNumElements()) {
367  for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
368  if (constituents[index].getType() != cType.getElementType(index)) {
369  return emitOpError("operand type mismatch: expected operand type ")
370  << cType.getElementType(index) << ", but provided "
371  << constituents[index].getType();
372  }
373  }
374  return success();
375  }
376 
377  // Case 4. -- check that all constituents add up to the expected vector type.
 // NOTE(review): this branch is also reached when there are MORE constituents
 // than result elements, although the message below says "less than".
378  auto resultType = llvm::dyn_cast<VectorType>(cType);
379  if (!resultType)
380  return emitOpError(
381  "expected to return a vector or cooperative matrix when the number of "
382  "constituents is less than what the result needs");
383 
 // Number of scalar components contributed by each constituent.
384  SmallVector<unsigned> sizes;
385  for (Value component : constituents) {
386  if (!llvm::isa<VectorType>(component.getType()) &&
387  !component.getType().isIntOrFloat())
388  return emitOpError("operand type mismatch: expected operand to have "
389  "a scalar or vector type, but provided ")
390  << component.getType();
391 
392  Type elementType = component.getType();
393  if (auto vectorType = llvm::dyn_cast<VectorType>(component.getType())) {
394  sizes.push_back(vectorType.getNumElements());
395  elementType = vectorType.getElementType();
396  } else {
397  sizes.push_back(1);
398  }
399 
400  if (elementType != resultType.getElementType())
401  return emitOpError("operand element type mismatch: expected to be ")
402  << resultType.getElementType() << ", but provided " << elementType;
403  }
 // The scalar components must exactly fill the result vector.
404  unsigned totalCount = std::accumulate(sizes.begin(), sizes.end(), 0);
405  if (totalCount != cType.getNumElements())
406  return emitOpError("has incorrect number of operands: expected ")
407  << cType.getNumElements() << ", but provided " << totalCount;
408  return success();
409 }
410 
411 //===----------------------------------------------------------------------===//
412 // spirv.CompositeExtractOp
413 //===----------------------------------------------------------------------===//
414 
415 void spirv::CompositeExtractOp::build(OpBuilder &builder, OperationState &state,
416  Value composite,
417  ArrayRef<int32_t> indices) {
418  auto indexAttr = builder.getI32ArrayAttr(indices);
419  auto elementType =
420  getElementType(composite.getType(), indexAttr, state.location);
421  if (!elementType) {
422  return;
423  }
424  build(builder, state, elementType, composite, indexAttr);
425 }
426 
428  OperationState &result) {
 // Parses `%composite[indices] : <composite type>` and infers the result type
 // from the composite type and the indices. Index errors are reported at the
 // location of the indices attribute.
 // NOTE(review): the first signature line (original line 427) is missing from
 // this listing.
429  OpAsmParser::UnresolvedOperand compositeInfo;
430  Attribute indicesAttr;
431  StringRef indicesAttrName =
432  spirv::CompositeExtractOp::getIndicesAttrName(result.name);
433  Type compositeType;
434  SMLoc attrLocation;
435 
436  if (parser.parseOperand(compositeInfo) ||
437  parser.getCurrentLocation(&attrLocation) ||
438  parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
439  parser.parseColonType(compositeType) ||
440  parser.resolveOperand(compositeInfo, compositeType, result.operands)) {
441  return failure();
442  }
443 
444  Type resultType =
445  getElementType(compositeType, indicesAttr, parser, attrLocation);
446  if (!resultType) {
447  return failure();
448  }
449  result.addTypes(resultType);
450  return success();
451 }
452 
 // Prints ` %composite[indices] : <composite type>`. The result type is not
 // printed because the parser re-derives it.
 // NOTE(review): the signature line (original line 453) is missing from this
 // listing.
454  printer << ' ' << getComposite() << getIndices() << " : "
455  << getComposite().getType();
456 }
457 
458 LogicalResult spirv::CompositeExtractOp::verify() {
459  auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(getIndices());
460  auto resultType =
461  getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
462  if (!resultType)
463  return failure();
464 
465  if (resultType != getType()) {
466  return emitOpError("invalid result type: expected ")
467  << resultType << " but provided " << getType();
468  }
469 
470  return success();
471 }
472 
473 //===----------------------------------------------------------------------===//
474 // spirv.CompositeInsert
475 //===----------------------------------------------------------------------===//
476 
477 void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state,
478  Value object, Value composite,
479  ArrayRef<int32_t> indices) {
480  auto indexAttr = builder.getI32ArrayAttr(indices);
481  build(builder, state, composite.getType(), object, composite, indexAttr);
482 }
483 
// Parses `%object, %composite[indices] : <object type> into <composite type>`.
// The result type is the composite type. Operand-resolution errors are
// reported at the location where parsing of this op began.
// NOTE(review): original line 486 (presumably the `operands` list
// declaration) is missing from this listing.
484 ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
485  OperationState &result) {
487  Type objectType, compositeType;
488  Attribute indicesAttr;
489  StringRef indicesAttrName =
490  spirv::CompositeInsertOp::getIndicesAttrName(result.name);
491  auto loc = parser.getCurrentLocation();
492 
493  return failure(
494  parser.parseOperandList(operands, 2) ||
495  parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
496  parser.parseColonType(objectType) ||
497  parser.parseKeywordType("into", compositeType) ||
498  parser.resolveOperands(operands, {objectType, compositeType}, loc,
499  result.operands) ||
500  parser.addTypesToList(compositeType, result.types));
501 }
502 
503 LogicalResult spirv::CompositeInsertOp::verify() {
504  auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(getIndices());
505  auto objectType =
506  getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
507  if (!objectType)
508  return failure();
509 
510  if (objectType != getObject().getType()) {
511  return emitOpError("object operand type should be ")
512  << objectType << ", but found " << getObject().getType();
513  }
514 
515  if (getComposite().getType() != getType()) {
516  return emitOpError("result type should be the same as "
517  "the composite type, but found ")
518  << getComposite().getType() << " vs " << getType();
519  }
520 
521  return success();
522 }
523 
 // Prints ` %object, %composite[indices] : <object type> into <composite type>`,
 // mirroring CompositeInsertOp::parse.
 // NOTE(review): the signature line (original line 524) is missing from this
 // listing.
525  printer << " " << getObject() << ", " << getComposite() << getIndices()
526  << " : " << getObject().getType() << " into "
527  << getComposite().getType();
528 }
529 
530 //===----------------------------------------------------------------------===//
531 // spirv.Constant
532 //===----------------------------------------------------------------------===//
533 
534 ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
535  OperationState &result) {
536  Attribute value;
537  StringRef valueAttrName = spirv::ConstantOp::getValueAttrName(result.name);
538  if (parser.parseAttribute(value, valueAttrName, result.attributes))
539  return failure();
540 
541  Type type = NoneType::get(parser.getContext());
542  if (auto typedAttr = llvm::dyn_cast<TypedAttr>(value))
543  type = typedAttr.getType();
544  if (llvm::isa<NoneType, TensorType>(type)) {
545  if (parser.parseColonType(type))
546  return failure();
547  }
548 
549  return parser.addTypeToList(type, result.types);
550 }
551 
552 void spirv::ConstantOp::print(OpAsmPrinter &printer) {
553  printer << ' ' << getValue();
554  if (llvm::isa<spirv::ArrayType>(getType()))
555  printer << " : " << getType();
556 }
557 
558 static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
559  Type opType) {
560  if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
561  auto valueType = llvm::cast<TypedAttr>(value).getType();
562  if (valueType != opType)
563  return op.emitOpError("result type (")
564  << opType << ") does not match value type (" << valueType << ")";
565  return success();
566  }
567  if (llvm::isa<DenseIntOrFPElementsAttr, SparseElementsAttr>(value)) {
568  auto valueType = llvm::cast<TypedAttr>(value).getType();
569  if (valueType == opType)
570  return success();
571  auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
572  auto shapedType = llvm::dyn_cast<ShapedType>(valueType);
573  if (!arrayType)
574  return op.emitOpError("result or element type (")
575  << opType << ") does not match value type (" << valueType
576  << "), must be the same or spirv.array";
577 
578  int numElements = arrayType.getNumElements();
579  auto opElemType = arrayType.getElementType();
580  while (auto t = llvm::dyn_cast<spirv::ArrayType>(opElemType)) {
581  numElements *= t.getNumElements();
582  opElemType = t.getElementType();
583  }
584  if (!opElemType.isIntOrFloat())
585  return op.emitOpError("only support nested array result type");
586 
587  auto valueElemType = shapedType.getElementType();
588  if (valueElemType != opElemType) {
589  return op.emitOpError("result element type (")
590  << opElemType << ") does not match value element type ("
591  << valueElemType << ")";
592  }
593 
594  if (numElements != shapedType.getNumElements()) {
595  return op.emitOpError("result number of elements (")
596  << numElements << ") does not match value number of elements ("
597  << shapedType.getNumElements() << ")";
598  }
599  return success();
600  }
601  if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(value)) {
602  auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
603  if (!arrayType)
604  return op.emitOpError(
605  "must have spirv.array result type for array value");
606  Type elemType = arrayType.getElementType();
607  for (Attribute element : arrayAttr.getValue()) {
608  // Verify array elements recursively.
609  if (failed(verifyConstantType(op, element, elemType)))
610  return failure();
611  }
612  return success();
613  }
614  return op.emitOpError("cannot have attribute: ") << value;
615 }
616 
617 LogicalResult spirv::ConstantOp::verify() {
618  // ODS already generates checks to make sure the result type is valid. We just
619  // need to additionally check that the value's attribute type is consistent
620  // with the result type.
621  return verifyConstantType(*this, getValueAttr(), getType());
622 }
623 
624 bool spirv::ConstantOp::isBuildableWith(Type type) {
625  // Must be valid SPIR-V type first.
626  if (!llvm::isa<spirv::SPIRVType>(type))
627  return false;
628 
629  if (isa<SPIRVDialect>(type.getDialect())) {
630  // TODO: support constant struct
631  return llvm::isa<spirv::ArrayType>(type);
632  }
633 
634  return true;
635 }
636 
637 spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
638  OpBuilder &builder) {
639  if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
640  unsigned width = intType.getWidth();
641  if (width == 1)
642  return builder.create<spirv::ConstantOp>(loc, type,
643  builder.getBoolAttr(false));
644  return builder.create<spirv::ConstantOp>(
645  loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
646  }
647  if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
648  return builder.create<spirv::ConstantOp>(
649  loc, type, builder.getFloatAttr(floatType, 0.0));
650  }
651  if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
652  Type elemType = vectorType.getElementType();
653  if (llvm::isa<IntegerType>(elemType)) {
654  return builder.create<spirv::ConstantOp>(
655  loc, type,
656  DenseElementsAttr::get(vectorType,
657  IntegerAttr::get(elemType, 0).getValue()));
658  }
659  if (llvm::isa<FloatType>(elemType)) {
660  return builder.create<spirv::ConstantOp>(
661  loc, type,
662  DenseFPElementsAttr::get(vectorType,
663  FloatAttr::get(elemType, 0.0).getValue()));
664  }
665  }
666 
667  llvm_unreachable("unimplemented types for ConstantOp::getZero()");
668 }
669 
670 spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
671  OpBuilder &builder) {
672  if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
673  unsigned width = intType.getWidth();
674  if (width == 1)
675  return builder.create<spirv::ConstantOp>(loc, type,
676  builder.getBoolAttr(true));
677  return builder.create<spirv::ConstantOp>(
678  loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
679  }
680  if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
681  return builder.create<spirv::ConstantOp>(
682  loc, type, builder.getFloatAttr(floatType, 1.0));
683  }
684  if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
685  Type elemType = vectorType.getElementType();
686  if (llvm::isa<IntegerType>(elemType)) {
687  return builder.create<spirv::ConstantOp>(
688  loc, type,
689  DenseElementsAttr::get(vectorType,
690  IntegerAttr::get(elemType, 1).getValue()));
691  }
692  if (llvm::isa<FloatType>(elemType)) {
693  return builder.create<spirv::ConstantOp>(
694  loc, type,
695  DenseFPElementsAttr::get(vectorType,
696  FloatAttr::get(elemType, 1.0).getValue()));
697  }
698  }
699 
700  llvm_unreachable("unimplemented types for ConstantOp::getOne()");
701 }
702 
703 void mlir::spirv::ConstantOp::getAsmResultNames(
704  llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
705  Type type = getType();
706 
707  SmallString<32> specialNameBuffer;
708  llvm::raw_svector_ostream specialName(specialNameBuffer);
709  specialName << "cst";
710 
711  IntegerType intTy = llvm::dyn_cast<IntegerType>(type);
712 
713  if (IntegerAttr intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
714  if (intTy && intTy.getWidth() == 1) {
715  return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
716  }
717 
718  if (intTy.isSignless()) {
719  specialName << intCst.getInt();
720  } else if (intTy.isUnsigned()) {
721  specialName << intCst.getUInt();
722  } else {
723  specialName << intCst.getSInt();
724  }
725  }
726 
727  if (intTy || llvm::isa<FloatType>(type)) {
728  specialName << '_' << type;
729  }
730 
731  if (auto vecType = llvm::dyn_cast<VectorType>(type)) {
732  specialName << "_vec_";
733  specialName << vecType.getDimSize(0);
734 
735  Type elementType = vecType.getElementType();
736 
737  if (llvm::isa<IntegerType>(elementType) ||
738  llvm::isa<FloatType>(elementType)) {
739  specialName << "x" << elementType;
740  }
741  }
742 
743  setNameFn(getResult(), specialName.str());
744 }
745 
746 void mlir::spirv::AddressOfOp::getAsmResultNames(
747  llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
748  SmallString<32> specialNameBuffer;
749  llvm::raw_svector_ostream specialName(specialNameBuffer);
750  specialName << getVariable() << "_addr";
751  setNameFn(getResult(), specialName.str());
752 }
753 
754 //===----------------------------------------------------------------------===//
755 // spirv.ControlBarrierOp
756 //===----------------------------------------------------------------------===//
757 
758 LogicalResult spirv::ControlBarrierOp::verify() {
759  return verifyMemorySemantics(getOperation(), getMemorySemantics());
760 }
761 
762 //===----------------------------------------------------------------------===//
763 // spirv.EntryPoint
764 //===----------------------------------------------------------------------===//
765 
766 void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state,
767  spirv::ExecutionModel executionModel,
768  spirv::FuncOp function,
769  ArrayRef<Attribute> interfaceVars) {
770  build(builder, state,
771  spirv::ExecutionModelAttr::get(builder.getContext(), executionModel),
772  SymbolRefAttr::get(function), builder.getArrayAttr(interfaceVars));
773 }
774 
// Parses `"<ExecutionModel>" @fn [, @var0, @var1, ...]`. The optional
// comma-separated symbol list becomes the `interface` ArrayAttr, which is
// always added (empty when no list is present).
// NOTE(review): original lines 778 and 782 are missing from this listing; one
// presumably declares `fn` (a FlatSymbolRefAttr). `idTypes` is never used.
775 ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
776  OperationState &result) {
777  spirv::ExecutionModel execModel;
779  SmallVector<Type, 0> idTypes;
780  SmallVector<Attribute, 4> interfaceVars;
781 
783  if (parseEnumStrAttr<spirv::ExecutionModelAttr>(execModel, parser, result) ||
784  parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) {
785  return failure();
786  }
787 
788  if (!parser.parseOptionalComma()) {
789  // Parse the interface variables
790  if (parser.parseCommaSeparatedList([&]() -> ParseResult {
791  // The name of the interface variable attribute isn't important
 // (it is parsed into a throwaway NamedAttrList; only `var` is kept).
792  FlatSymbolRefAttr var;
793  NamedAttrList attrs;
794  if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
795  return failure();
796  interfaceVars.push_back(var);
797  return success();
798  }))
799  return failure();
800  }
801  result.addAttribute(spirv::EntryPointOp::getInterfaceAttrName(result.name),
802  parser.getBuilder().getArrayAttr(interfaceVars));
803  return success();
804 }
805 
 // Prints ` "<ExecutionModel>" @fn`, followed by `, @var...` when the
 // interface list is non-empty (the inverse of EntryPointOp::parse).
 // NOTE(review): the signature line (original line 806) is missing from this
 // listing.
807  printer << " \"" << stringifyExecutionModel(getExecutionModel()) << "\" ";
808  printer.printSymbolName(getFn());
809  auto interfaceVars = getInterface().getValue();
810  if (!interfaceVars.empty()) {
811  printer << ", ";
812  llvm::interleaveComma(interfaceVars, printer);
813  }
814 }
815 
816 LogicalResult spirv::EntryPointOp::verify() {
817  // Checks for fn and interface symbol reference are done in spirv::ModuleOp
818  // verification.
819  return success();
820 }
821 
822 //===----------------------------------------------------------------------===//
823 // spirv.ExecutionMode
824 //===----------------------------------------------------------------------===//
825 
826 void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state,
827  spirv::FuncOp function,
828  spirv::ExecutionMode executionMode,
829  ArrayRef<int32_t> params) {
830  build(builder, state, SymbolRefAttr::get(function),
831  spirv::ExecutionModeAttr::get(builder.getContext(), executionMode),
832  builder.getI32ArrayAttr(params));
833 }
834 
// Parses `@fn "<ExecutionMode>" [, <i32>, <i32>, ...]`. The trailing integers
// are parsed as i32-typed attributes and stored as the `values` I32ArrayAttr
// (always added, possibly empty).
// NOTE(review): original line 844 is missing from this listing; it presumably
// declares `values` (a SmallVector<int32_t>). llvm::cast<IntegerAttr> asserts
// if the parsed literal is not an IntegerAttr -- confirm that parseAttribute
// with an i32 type rejects every other literal kind.
835 ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
836  OperationState &result) {
837  spirv::ExecutionMode execMode;
838  Attribute fn;
839  if (parser.parseAttribute(fn, kFnNameAttrName, result.attributes) ||
840  parseEnumStrAttr<spirv::ExecutionModeAttr>(execMode, parser, result)) {
841  return failure();
842  }
843 
845  Type i32Type = parser.getBuilder().getIntegerType(32);
846  while (!parser.parseOptionalComma()) {
 // Throwaway attribute list; only the parsed value is kept.
847  NamedAttrList attr;
848  Attribute value;
849  if (parser.parseAttribute(value, i32Type, "value", attr)) {
850  return failure();
851  }
852  values.push_back(llvm::cast<IntegerAttr>(value).getInt());
853  }
854  StringRef valuesAttrName =
855  spirv::ExecutionModeOp::getValuesAttrName(result.name);
856  result.addAttribute(valuesAttrName,
857  parser.getBuilder().getI32ArrayAttr(values));
858  return success();
859 }
860 
 // Prints ` @fn "<ExecutionMode>"`, followed by `, v0, v1, ...` when `values`
 // is non-empty (the inverse of ExecutionModeOp::parse).
 // NOTE(review): the signature line (original line 861) is missing from this
 // listing.
862  printer << " ";
863  printer.printSymbolName(getFn());
864  printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\"";
865  auto values = this->getValues();
866  if (values.empty())
867  return;
868  printer << ", ";
869  llvm::interleaveComma(values, printer, [&](Attribute a) {
870  printer << llvm::cast<IntegerAttr>(a).getInt();
871  });
872 }
873 
874 //===----------------------------------------------------------------------===//
875 // spirv.func
876 //===----------------------------------------------------------------------===//
877 
// Parses a spirv.func, in order:
//  1. the symbol name;
//  2. the signature (variadic arguments are not allowed);
//  3. the quoted function-control string;
//  4. an optional `attributes {...}` dictionary;
//  5. an optional body region.
// The function type is built from the parsed argument/result types, and
// argument/result attributes are attached via the arg/res attrs names.
// NOTE(review): original lines 879 (presumably the `entryArgs` declaration),
// 892 (presumably the signature-parsing call) and 915 (presumably the call
// that attaches arg/result attributes) are missing from this listing.
878 ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
880  SmallVector<DictionaryAttr> resultAttrs;
881  SmallVector<Type> resultTypes;
882  auto &builder = parser.getBuilder();
883 
884  // Parse the name as a symbol.
885  StringAttr nameAttr;
886  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
887  result.attributes))
888  return failure();
889 
890  // Parse the function signature.
891  bool isVariadic = false;
893  parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
894  resultAttrs))
895  return failure();
896 
 // Build the FunctionType from the entry-argument types and result types.
897  SmallVector<Type> argTypes;
898  for (auto &arg : entryArgs)
899  argTypes.push_back(arg.type);
900  auto fnType = builder.getFunctionType(argTypes, resultTypes);
901  result.addAttribute(getFunctionTypeAttrName(result.name),
902  TypeAttr::get(fnType));
903 
904  // Parse the optional function control keyword.
 // NOTE(review): parseEnumStrAttr is called unconditionally and its failure is
 // propagated, so confirm whether omitting the keyword is really accepted.
905  spirv::FunctionControl fnControl;
906  if (parseEnumStrAttr<spirv::FunctionControlAttr>(fnControl, parser, result))
907  return failure();
908 
909  // If additional attributes are present, parse them.
910  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
911  return failure();
912 
913  // Add the attributes to the function arguments.
914  assert(resultAttrs.size() == resultTypes.size());
916  builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
917  getResAttrsAttrName(result.name));
918 
919  // Parse the optional function body.
 // An absent body yields a declaration; only a present-but-malformed body
 // is an error.
920  auto *body = result.addRegion();
921  OptionalParseResult parseResult =
922  parser.parseOptionalRegion(*body, entryArgs);
923  return failure(parseResult.has_value() && failed(*parseResult));
924 }
925 
/// Custom printer for spirv.func. Emits the symbol name, the signature, and
/// the function control as a quoted string. It then prints the remaining
/// attributes, eliding those already encoded above, and finally the body
/// region when the function is not external.
926 void spirv::FuncOp::print(OpAsmPrinter &printer) {
927  // Print function name, signature, and control.
928  printer << " ";
929  printer.printSymbolName(getSymName());
930  auto fnType = getFunctionType();
  // NOTE(review): the opening of the signature-printing call continued below
  // is not visible in this listing.
932  printer, *this, fnType.getInputs(),
933  /*isVariadic=*/false, fnType.getResults());
934  printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl())
935  << "\"";
  // NOTE(review): the opening of the attribute-printing call continued below
  // is not visible in this listing. The braced list appears to name the
  // attributes to elide: function control, function type, and arg/result
  // attributes, all of which are already rendered by the custom syntax.
937  printer, *this,
938  {spirv::attributeName<spirv::FunctionControl>(),
939  getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
940  getFunctionControlAttrName()});
941 
942  // Print the body if this is not an external function.
943  Region &body = this->getBody();
944  if (!body.empty()) {
945  printer << ' ';
  // Entry block arguments are not re-printed here; the signature above
  // already lists the argument types.
946  printer.printRegion(body, /*printEntryBlockArgs=*/false,
947  /*printBlockTerminators=*/true);
948  }
949 }
950 
951 LogicalResult spirv::FuncOp::verifyType() {
952  FunctionType fnType = getFunctionType();
953  if (fnType.getNumResults() > 1)
954  return emitOpError("cannot have more than one result");
955 
956  auto hasDecorationAttr = [&](spirv::Decoration decoration,
957  unsigned argIndex) {
958  auto func = llvm::cast<FunctionOpInterface>(getOperation());
959  for (auto argAttr : cast<FunctionOpInterface>(func).getArgAttrs(argIndex)) {
960  if (argAttr.getName() != spirv::DecorationAttr::name)
961  continue;
962  if (auto decAttr = dyn_cast<spirv::DecorationAttr>(argAttr.getValue()))
963  return decAttr.getValue() == decoration;
964  }
965  return false;
966  };
967 
968  for (unsigned i = 0, e = this->getNumArguments(); i != e; ++i) {
969  Type param = fnType.getInputs()[i];
970  auto inputPtrType = dyn_cast<spirv::PointerType>(param);
971  if (!inputPtrType)
972  continue;
973 
974  auto pointeePtrType =
975  dyn_cast<spirv::PointerType>(inputPtrType.getPointeeType());
976  if (pointeePtrType) {
977  // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
978  // > If an OpFunctionParameter is a pointer (or contains a pointer)
979  // > and the type it points to is a pointer in the PhysicalStorageBuffer
980  // > storage class, the function parameter must be decorated with exactly
981  // > one of AliasedPointer or RestrictPointer.
982  if (pointeePtrType.getStorageClass() !=
983  spirv::StorageClass::PhysicalStorageBuffer)
984  continue;
985 
986  bool hasAliasedPtr =
987  hasDecorationAttr(spirv::Decoration::AliasedPointer, i);
988  bool hasRestrictPtr =
989  hasDecorationAttr(spirv::Decoration::RestrictPointer, i);
990  if (!hasAliasedPtr && !hasRestrictPtr)
991  return emitOpError()
992  << "with a pointer points to a physical buffer pointer must "
993  "be decorated either 'AliasedPointer' or 'RestrictPointer'";
994  continue;
995  }
996  // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
997  // > If an OpFunctionParameter is a pointer (or contains a pointer) in
998  // > the PhysicalStorageBuffer storage class, the function parameter must
999  // > be decorated with exactly one of Aliased or Restrict.
1000  if (auto pointeeArrayType =
1001  dyn_cast<spirv::ArrayType>(inputPtrType.getPointeeType())) {
1002  pointeePtrType =
1003  dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
1004  } else {
1005  pointeePtrType = inputPtrType;
1006  }
1007 
1008  if (!pointeePtrType || pointeePtrType.getStorageClass() !=
1009  spirv::StorageClass::PhysicalStorageBuffer)
1010  continue;
1011 
1012  bool hasAliased = hasDecorationAttr(spirv::Decoration::Aliased, i);
1013  bool hasRestrict = hasDecorationAttr(spirv::Decoration::Restrict, i);
1014  if (!hasAliased && !hasRestrict)
1015  return emitOpError() << "with physical buffer pointer must be decorated "
1016  "either 'Aliased' or 'Restrict'";
1017  }
1018 
1019  return success();
1020 }
1021 
1022 LogicalResult spirv::FuncOp::verifyBody() {
1023  FunctionType fnType = getFunctionType();
1024 
1025  auto walkResult = walk([fnType](Operation *op) -> WalkResult {
1026  if (auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
1027  if (fnType.getNumResults() != 0)
1028  return retOp.emitOpError("cannot be used in functions returning value");
1029  } else if (auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
1030  if (fnType.getNumResults() != 1)
1031  return retOp.emitOpError(
1032  "returns 1 value but enclosing function requires ")
1033  << fnType.getNumResults() << " results";
1034 
1035  auto retOperandType = retOp.getValue().getType();
1036  auto fnResultType = fnType.getResult(0);
1037  if (retOperandType != fnResultType)
1038  return retOp.emitOpError(" return value's type (")
1039  << retOperandType << ") mismatch with function's result type ("
1040  << fnResultType << ")";
1041  }
1042  return WalkResult::advance();
1043  });
1044 
1045  // TODO: verify other bits like linkage type.
1046 
1047  return failure(walkResult.wasInterrupted());
1048 }
1049 
1050 void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
1051  StringRef name, FunctionType type,
1052  spirv::FunctionControl control,
1053  ArrayRef<NamedAttribute> attrs) {
1054  state.addAttribute(SymbolTable::getSymbolAttrName(),
1055  builder.getStringAttr(name));
1056  state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
1057  state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
1058  builder.getAttr<spirv::FunctionControlAttr>(control));
1059  state.attributes.append(attrs.begin(), attrs.end());
1060  state.addRegion();
1061 }
1062 
1063 //===----------------------------------------------------------------------===//
1064 // spirv.GLFClampOp
1065 //===----------------------------------------------------------------------===//
1066 
1067 ParseResult spirv::GLFClampOp::parse(OpAsmParser &parser,
1068  OperationState &result) {
1069  return parseOneResultSameOperandTypeOp(parser, result);
1070 }
1072 
1073 //===----------------------------------------------------------------------===//
1074 // spirv.GLUClampOp
1075 //===----------------------------------------------------------------------===//
1076 
1077 ParseResult spirv::GLUClampOp::parse(OpAsmParser &parser,
1078  OperationState &result) {
1079  return parseOneResultSameOperandTypeOp(parser, result);
1080 }
1082 
1083 //===----------------------------------------------------------------------===//
1084 // spirv.GLSClampOp
1085 //===----------------------------------------------------------------------===//
1086 
1087 ParseResult spirv::GLSClampOp::parse(OpAsmParser &parser,
1088  OperationState &result) {
1089  return parseOneResultSameOperandTypeOp(parser, result);
1090 }
1092 
1093 //===----------------------------------------------------------------------===//
1094 // spirv.GLFmaOp
1095 //===----------------------------------------------------------------------===//
1096 
1097 ParseResult spirv::GLFmaOp::parse(OpAsmParser &parser, OperationState &result) {
1098  return parseOneResultSameOperandTypeOp(parser, result);
1099 }
1100 void spirv::GLFmaOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1101 
1102 //===----------------------------------------------------------------------===//
1103 // spirv.GlobalVariable
1104 //===----------------------------------------------------------------------===//
1105 
1106 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
1107  Type type, StringRef name,
1108  unsigned descriptorSet, unsigned binding) {
1109  build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
1110  state.addAttribute(
1111  spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
1112  builder.getI32IntegerAttr(descriptorSet));
1113  state.addAttribute(
1114  spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
1115  builder.getI32IntegerAttr(binding));
1116 }
1117 
1118 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
1119  Type type, StringRef name,
1120  spirv::BuiltIn builtin) {
1121  build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
1122  state.addAttribute(
1123  spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
1124  builder.getStringAttr(spirv::stringifyBuiltIn(builtin)));
1125 }
1126 
1127 ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
1128  OperationState &result) {
1129  // Parse variable name.
1130  StringAttr nameAttr;
1131  StringRef initializerAttrName =
1132  spirv::GlobalVariableOp::getInitializerAttrName(result.name);
1133  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1134  result.attributes)) {
1135  return failure();
1136  }
1137 
1138  // Parse optional initializer
1139  if (succeeded(parser.parseOptionalKeyword(initializerAttrName))) {
1140  FlatSymbolRefAttr initSymbol;
1141  if (parser.parseLParen() ||
1142  parser.parseAttribute(initSymbol, Type(), initializerAttrName,
1143  result.attributes) ||
1144  parser.parseRParen())
1145  return failure();
1146  }
1147 
1148  if (parseVariableDecorations(parser, result)) {
1149  return failure();
1150  }
1151 
1152  Type type;
1153  StringRef typeAttrName =
1154  spirv::GlobalVariableOp::getTypeAttrName(result.name);
1155  auto loc = parser.getCurrentLocation();
1156  if (parser.parseColonType(type)) {
1157  return failure();
1158  }
1159  if (!llvm::isa<spirv::PointerType>(type)) {
1160  return parser.emitError(loc, "expected spirv.ptr type");
1161  }
1162  result.addAttribute(typeAttrName, TypeAttr::get(type));
1163 
1164  return success();
1165 }
1166 
// Body of the spirv.GlobalVariable custom printer (presumably
// spirv::GlobalVariableOp::print; its signature line is not shown in this
// listing). The output is ` @name [initializer(@sym)] <decorations> : <type>`.
// Each attribute rendered inline is appended to `elidedAttrs`, which is handed
// to printVariableDecorations (presumably so it is not printed twice).
1168  SmallVector<StringRef, 4> elidedAttrs{
1169  spirv::attributeName<spirv::StorageClass>()};
1170 
1171  // Print variable name.
1172  printer << ' ';
1173  printer.printSymbolName(getSymName());
1174  elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
1175 
1176  StringRef initializerAttrName = this->getInitializerAttrName();
1177  // Print optional initializer
1178  if (auto initializer = this->getInitializer()) {
1179  printer << " " << initializerAttrName << '(';
1180  printer.printSymbolName(*initializer);
1181  printer << ')';
1182  elidedAttrs.push_back(initializerAttrName);
1183  }
1184 
  // The type is always printed after the colon, so it is always elided.
1185  StringRef typeAttrName = this->getTypeAttrName();
1186  elidedAttrs.push_back(typeAttrName);
1187  spirv::printVariableDecorations(*this, printer, elidedAttrs);
1188  printer << " : " << getType();
1189 }
1190 
/// Verifies spirv.GlobalVariable:
///  - the type is a spirv::PointerType;
///  - the storage class is neither Generic nor Function;
///  - an initializer, if present, resolves to a spirv.GlobalVariable,
///    spirv.SpecConstant or spirv.SpecConstantComposite.
1191 LogicalResult spirv::GlobalVariableOp::verify() {
  // NOTE(review): this message spells the type "!spv.ptr", while the parser's
  // diagnostic uses "spirv.ptr"; consider aligning it.
1192  if (!llvm::isa<spirv::PointerType>(getType()))
1193  return emitOpError("result must be of a !spv.ptr type");
1194 
1195  // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
1196  // object. It cannot be Generic. It must be the same as the Storage Class
1197  // operand of the Result Type."
1198  // Also, Function storage class is reserved by spirv.Variable.
1199  auto storageClass = this->storageClass();
1200  if (storageClass == spirv::StorageClass::Generic ||
1201  storageClass == spirv::StorageClass::Function) {
1202  return emitOpError("storage class cannot be '")
1203  << stringifyStorageClass(storageClass) << "'";
1204  }
1205 
1206  if (auto init = (*this)->getAttrOfType<FlatSymbolRefAttr>(
1207  this->getInitializerAttrName())) {
  // NOTE(review): the declaration of `initOp` (a symbol lookup starting from
  // the parent op, judging by the arguments below) is not visible in this
  // listing.
1209  (*this)->getParentOp(), init.getAttr());
1210  // TODO: Currently only variable initialization with specialization
1211  // constants and other variables is supported. They could be normal
1212  // constants in the module scope as well.
1213  if (!initOp || !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp,
1214  spirv::SpecConstantCompositeOp>(initOp)) {
1215  return emitOpError("initializer must be result of a "
1216  "spirv.SpecConstant or spirv.GlobalVariable or "
1217  "spirv.SpecConstantCompositeOp op");
1218  }
1219  }
1220 
1221  return success();
1222 }
1223 
1224 //===----------------------------------------------------------------------===//
1225 // spirv.INTEL.SubgroupBlockRead
1226 //===----------------------------------------------------------------------===//
1227 
// Body of the spirv.INTEL.SubgroupBlockRead verifier (its signature line is
// not shown in this listing). The pointer/value type compatibility check is
// delegated to verifyBlockReadWritePtrAndValTypes.
1229  if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
1230  return failure();
1231 
1232  return success();
1233 }
1234 
1235 //===----------------------------------------------------------------------===//
1236 // spirv.INTEL.SubgroupBlockWrite
1237 //===----------------------------------------------------------------------===//
1238 
// Remainder of the spirv.INTEL.SubgroupBlockWrite custom parser (its first
// signature line is not shown in this listing). The expected form is
// `"StorageClass" %ptr, %value : <type>`. The pointer operand's type is
// reconstructed from the storage class: it points to the value type, or to
// its element type when the value type is a vector.
1240  OperationState &result) {
1241  // Parse the storage class specification
1242  spirv::StorageClass storageClass;
  // NOTE(review): the declaration of `operandInfo` is not visible in this
  // listing.
1244  auto loc = parser.getCurrentLocation();
1245  Type elementType;
1246  if (parseEnumStrAttr(storageClass, parser) ||
1247  parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
1248  parser.parseType(elementType)) {
1249  return failure();
1250  }
1251 
1252  auto ptrType = spirv::PointerType::get(elementType, storageClass);
1253  if (auto valVecTy = llvm::dyn_cast<VectorType>(elementType))
1254  ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
1255 
  // Operands are resolved in order: pointer first, then the value.
1256  if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
1257  result.operands)) {
1258  return failure();
1259  }
1260  return success();
1261 }
1262 
// Body of the spirv.INTEL.SubgroupBlockWrite printer (its signature line is
// not shown in this listing). It prints ` %ptr, %value : <value type>`.
// NOTE(review): the parser for this op first requires a storage-class string
// (parseEnumStrAttr), but nothing here prints one. Printed IR may therefore
// fail to re-parse; confirm and consider emitting the pointer's storage class.
1264  printer << " " << getPtr() << ", " << getValue() << " : "
1265  << getValue().getType();
1266 }
1267 
// Body of the spirv.INTEL.SubgroupBlockWrite verifier (its signature line is
// not shown in this listing). It uses the same pointer/value type check as
// the block-read op.
1269  if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
1270  return failure();
1271 
1272  return success();
1273 }
1274 
1275 //===----------------------------------------------------------------------===//
1276 // spirv.IAddCarryOp
1277 //===----------------------------------------------------------------------===//
1278 
// NOTE(review): the bodies of IAddCarryOp::verify and ::parse are not visible
// in this listing. The sibling print delegates to
// ::printArithmeticExtendedBinaryOp, so these presumably delegate to matching
// shared helpers; confirm against the full source.
1279 LogicalResult spirv::IAddCarryOp::verify() {
1281 }
1282 
1283 ParseResult spirv::IAddCarryOp::parse(OpAsmParser &parser,
1284  OperationState &result) {
1286 }
1287 
1288 void spirv::IAddCarryOp::print(OpAsmPrinter &printer) {
1289  ::printArithmeticExtendedBinaryOp(*this, printer);
1290 }
1291 
1292 //===----------------------------------------------------------------------===//
1293 // spirv.ISubBorrowOp
1294 //===----------------------------------------------------------------------===//
1295 
// ISubBorrowOp verify/parse/print. NOTE(review): the verify and parse bodies
// and the print signature line are not visible in this listing. The visible
// print body delegates to ::printArithmeticExtendedBinaryOp.
1296 LogicalResult spirv::ISubBorrowOp::verify() {
1298 }
1299 
1300 ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser,
1301  OperationState &result) {
1303 }
1304 
1306  ::printArithmeticExtendedBinaryOp(*this, printer);
1307 }
1308 
1309 //===----------------------------------------------------------------------===//
1310 // spirv.SMulExtended
1311 //===----------------------------------------------------------------------===//
1312 
// SMulExtendedOp verify/parse/print. NOTE(review): the verify and parse bodies
// and the print signature line are not visible in this listing. The visible
// print body delegates to ::printArithmeticExtendedBinaryOp.
1313 LogicalResult spirv::SMulExtendedOp::verify() {
1315 }
1316 
1317 ParseResult spirv::SMulExtendedOp::parse(OpAsmParser &parser,
1318  OperationState &result) {
1320 }
1321 
1323  ::printArithmeticExtendedBinaryOp(*this, printer);
1324 }
1325 
1326 //===----------------------------------------------------------------------===//
1327 // spirv.UMulExtended
1328 //===----------------------------------------------------------------------===//
1329 
// UMulExtendedOp verify/parse/print. NOTE(review): the verify and parse bodies
// and the print signature line are not visible in this listing. The visible
// print body delegates to ::printArithmeticExtendedBinaryOp.
1330 LogicalResult spirv::UMulExtendedOp::verify() {
1332 }
1333 
1334 ParseResult spirv::UMulExtendedOp::parse(OpAsmParser &parser,
1335  OperationState &result) {
1337 }
1338 
1340  ::printArithmeticExtendedBinaryOp(*this, printer);
1341 }
1342 
1343 //===----------------------------------------------------------------------===//
1344 // spirv.MemoryBarrierOp
1345 //===----------------------------------------------------------------------===//
1346 
1347 LogicalResult spirv::MemoryBarrierOp::verify() {
1348  return verifyMemorySemantics(getOperation(), getMemorySemantics());
1349 }
1350 
1351 //===----------------------------------------------------------------------===//
1352 // spirv.module
1353 //===----------------------------------------------------------------------===//
1354 
1355 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
1356  std::optional<StringRef> name) {
1357  OpBuilder::InsertionGuard guard(builder);
1358  builder.createBlock(state.addRegion());
1359  if (name) {
1360  state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
1361  builder.getStringAttr(*name));
1362  }
1363 }
1364 
1365 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
1366  spirv::AddressingModel addressingModel,
1367  spirv::MemoryModel memoryModel,
1368  std::optional<VerCapExtAttr> vceTriple,
1369  std::optional<StringRef> name) {
1370  state.addAttribute(
1371  "addressing_model",
1372  builder.getAttr<spirv::AddressingModelAttr>(addressingModel));
1373  state.addAttribute("memory_model",
1374  builder.getAttr<spirv::MemoryModelAttr>(memoryModel));
1375  OpBuilder::InsertionGuard guard(builder);
1376  builder.createBlock(state.addRegion());
1377  if (vceTriple)
1378  state.addAttribute(getVCETripleAttrName(), *vceTriple);
1379  if (name)
1380  state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
1381  builder.getStringAttr(*name));
1382 }
1383 
1384 ParseResult spirv::ModuleOp::parse(OpAsmParser &parser,
1385  OperationState &result) {
1386  Region *body = result.addRegion();
1387 
1388  // If the name is present, parse it.
1389  StringAttr nameAttr;
1390  (void)parser.parseOptionalSymbolName(
1391  nameAttr, mlir::SymbolTable::getSymbolAttrName(), result.attributes);
1392 
1393  // Parse attributes
1394  spirv::AddressingModel addrModel;
1395  spirv::MemoryModel memoryModel;
1396  if (spirv::parseEnumKeywordAttr<spirv::AddressingModelAttr>(addrModel, parser,
1397  result) ||
1398  spirv::parseEnumKeywordAttr<spirv::MemoryModelAttr>(memoryModel, parser,
1399  result))
1400  return failure();
1401 
1402  if (succeeded(parser.parseOptionalKeyword("requires"))) {
1403  spirv::VerCapExtAttr vceTriple;
1404  if (parser.parseAttribute(vceTriple,
1405  spirv::ModuleOp::getVCETripleAttrName(),
1406  result.attributes))
1407  return failure();
1408  }
1409 
1410  if (parser.parseOptionalAttrDictWithKeyword(result.attributes) ||
1411  parser.parseRegion(*body, /*arguments=*/{}))
1412  return failure();
1413 
1414  // Make sure we have at least one block.
1415  if (body->empty())
1416  body->push_back(new Block());
1417 
1418  return success();
1419 }
1420 
/// Custom printer for spirv.module. It emits the optional symbol name, the
/// addressing and memory models as bare keywords, an optional
/// `requires <triple>`, and the remaining attributes. Attributes already
/// rendered inline are elided. The body region comes last.
1421 void spirv::ModuleOp::print(OpAsmPrinter &printer) {
1422  if (std::optional<StringRef> name = getName()) {
1423  printer << ' ';
1424  printer.printSymbolName(*name);
1425  }
1426 
1427  SmallVector<StringRef, 2> elidedAttrs;
1428 
1429  printer << " " << spirv::stringifyAddressingModel(getAddressingModel()) << " "
1430  << spirv::stringifyMemoryModel(getMemoryModel());
1431  auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
1432  auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
  // NOTE(review): the continuation of this initializer list (presumably the
  // symbol-name attribute, since the name is printed above) is not visible in
  // this listing.
1433  elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
1435 
1436  if (std::optional<spirv::VerCapExtAttr> triple = getVceTriple()) {
1437  printer << " requires " << *triple;
1438  elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
1439  }
1440 
1441  printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs);
1442  printer << ' ';
1443  printer.printRegion(getRegion());
1444 }
1445 
/// Verifies the contents of a spirv.module:
///  - every op directly in the body belongs to the module's dialect;
///  - each spirv.EntryPoint names an existing spirv.func;
///  - each entry point lists only spirv.GlobalVariable symbols as its
///    interface;
///  - each entry point is unique per (function, execution model) pair;
///  - external functions carry 'Import' linkage;
///  - ops directly inside function blocks also belong to the module's
///    dialect.
1446 LogicalResult spirv::ModuleOp::verifyRegions() {
1447  Dialect *dialect = (*this)->getDialect();
  // NOTE(review): the start of this declaration is not visible in this
  // listing. Judging by the try_emplace below, `entryPoints` maps
  // (FuncOp, ExecutionModel) pairs to the first EntryPointOp seen for them.
1449  entryPoints;
1450  mlir::SymbolTable table(*this);
1451 
1452  for (auto &op : *getBody()) {
1453  if (op.getDialect() != dialect)
1454  return op.emitError("'spirv.module' can only contain spirv.* ops");
1455 
1456  // For EntryPoint op, check that the function and execution model is not
1457  // duplicated in EntryPointOps. Also verify that the interface specified
1458  // comes from globalVariables here to make this check cheaper.
1459  if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
1460  auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.getFn());
1461  if (!funcOp) {
1462  return entryPointOp.emitError("function '")
1463  << entryPointOp.getFn() << "' not found in 'spirv.module'";
1464  }
1465  if (auto interface = entryPointOp.getInterface()) {
1466  for (Attribute varRef : interface) {
1467  auto varSymRef = llvm::dyn_cast<FlatSymbolRefAttr>(varRef);
1468  if (!varSymRef) {
  // NOTE(review): this message opens a quote before `varRef` but never
  // closes it.
1469  return entryPointOp.emitError(
1470  "expected symbol reference for interface "
1471  "specification instead of '")
1472  << varRef;
1473  }
1474  auto variableOp =
1475  table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
1476  if (!variableOp) {
  // NOTE(review): "instead of'" is missing a space before the quote.
1477  return entryPointOp.emitError("expected spirv.GlobalVariable "
1478  "symbol reference instead of'")
1479  << varSymRef << "'";
1480  }
1481  }
1482  }
1483 
  // try_emplace fails (second == false) when this (function, model) pair
  // was already registered by an earlier entry point.
1484  auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
1485  funcOp, entryPointOp.getExecutionModel());
1486  if (!entryPoints.try_emplace(key, entryPointOp).second)
1487  return entryPointOp.emitError("duplicate of a previous EntryPointOp");
1488  } else if (auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
1489  // If the function is external and does not have 'Import'
1490  // linkage_attributes(LinkageAttributes), throw an error. 'Import'
1491  // LinkageAttributes is used to import external functions.
1492  auto linkageAttr = funcOp.getLinkageAttributes();
1493  auto hasImportLinkage =
1494  linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
1495  spirv::LinkageType::Import);
1496  if (funcOp.isExternal() && !hasImportLinkage)
1497  return op.emitError(
1498  "'spirv.module' cannot contain external functions "
1499  "without 'Import' linkage_attributes (LinkageAttributes)");
1500 
1501  // TODO: move this check to spirv.func.
  // Only ops directly inside the function's blocks are checked; ops in
  // nested regions are not visited by this loop.
  // NOTE(review): the inner `op` shadows the outer loop variable `op`.
1502  for (auto &block : funcOp)
1503  for (auto &op : block) {
1504  if (op.getDialect() != dialect)
1505  return op.emitError(
1506  "functions in 'spirv.module' can only contain spirv.* ops");
1507  }
1508  }
1509  }
1510 
1511  return success();
1512 }
1513 
1514 //===----------------------------------------------------------------------===//
1515 // spirv.mlir.referenceof
1516 //===----------------------------------------------------------------------===//
1517 
1518 LogicalResult spirv::ReferenceOfOp::verify() {
1519  auto *specConstSym = SymbolTable::lookupNearestSymbolFrom(
1520  (*this)->getParentOp(), getSpecConstAttr());
1521  Type constType;
1522 
1523  auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
1524  if (specConstOp)
1525  constType = specConstOp.getDefaultValue().getType();
1526 
1527  auto specConstCompositeOp =
1528  dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
1529  if (specConstCompositeOp)
1530  constType = specConstCompositeOp.getType();
1531 
1532  if (!specConstOp && !specConstCompositeOp)
1533  return emitOpError(
1534  "expected spirv.SpecConstant or spirv.SpecConstantComposite symbol");
1535 
1536  if (getReference().getType() != constType)
1537  return emitOpError("result type mismatch with the referenced "
1538  "specialization constant's type");
1539 
1540  return success();
1541 }
1542 
1543 //===----------------------------------------------------------------------===//
1544 // spirv.SpecConstant
1545 //===----------------------------------------------------------------------===//
1546 
1547 ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
1548  OperationState &result) {
1549  StringAttr nameAttr;
1550  Attribute valueAttr;
1551  StringRef defaultValueAttrName =
1552  spirv::SpecConstantOp::getDefaultValueAttrName(result.name);
1553 
1554  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1555  result.attributes))
1556  return failure();
1557 
1558  // Parse optional spec_id.
1559  if (succeeded(parser.parseOptionalKeyword(kSpecIdAttrName))) {
1560  IntegerAttr specIdAttr;
1561  if (parser.parseLParen() ||
1562  parser.parseAttribute(specIdAttr, kSpecIdAttrName, result.attributes) ||
1563  parser.parseRParen())
1564  return failure();
1565  }
1566 
1567  if (parser.parseEqual() ||
1568  parser.parseAttribute(valueAttr, defaultValueAttrName, result.attributes))
1569  return failure();
1570 
1571  return success();
1572 }
1573 
// Body of the spirv.SpecConstant printer (presumably
// spirv::SpecConstantOp::print; its signature line is not shown in this
// listing). It prints ` @name`, then the spec-id clause only when the
// attribute is set, then ` = <default value>`.
1575  printer << ' ';
1576  printer.printSymbolName(getSymName());
1577  if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
1578  printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')';
1579  printer << " = " << getDefaultValue();
1580 }
1581 
1582 LogicalResult spirv::SpecConstantOp::verify() {
1583  if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
1584  if (specID.getValue().isNegative())
1585  return emitOpError("SpecId cannot be negative");
1586 
1587  auto value = getDefaultValue();
1588  if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
1589  // Make sure bitwidth is allowed.
1590  if (!llvm::isa<spirv::SPIRVType>(value.getType()))
1591  return emitOpError("default value bitwidth disallowed");
1592  return success();
1593  }
1594  return emitOpError(
1595  "default value can only be a bool, integer, or float scalar");
1596 }
1597 
1598 //===----------------------------------------------------------------------===//
1599 // spirv.VectorShuffle
1600 //===----------------------------------------------------------------------===//
1601 
1602 LogicalResult spirv::VectorShuffleOp::verify() {
1603  VectorType resultType = llvm::cast<VectorType>(getType());
1604 
1605  size_t numResultElements = resultType.getNumElements();
1606  if (numResultElements != getComponents().size())
1607  return emitOpError("result type element count (")
1608  << numResultElements
1609  << ") mismatch with the number of component selectors ("
1610  << getComponents().size() << ")";
1611 
1612  size_t totalSrcElements =
1613  llvm::cast<VectorType>(getVector1().getType()).getNumElements() +
1614  llvm::cast<VectorType>(getVector2().getType()).getNumElements();
1615 
1616  for (const auto &selector : getComponents().getAsValueRange<IntegerAttr>()) {
1617  uint32_t index = selector.getZExtValue();
1618  if (index >= totalSrcElements &&
1619  index != std::numeric_limits<uint32_t>().max())
1620  return emitOpError("component selector ")
1621  << index << " out of range: expected to be in [0, "
1622  << totalSrcElements << ") or 0xffffffff";
1623  }
1624  return success();
1625 }
1626 
1627 //===----------------------------------------------------------------------===//
1628 // spirv.MatrixTimesScalar
1629 //===----------------------------------------------------------------------===//
1630 
/// Verifies spirv.MatrixTimesScalar: the scalar operand's type must equal the
/// matrix operand's element type.
1631 LogicalResult spirv::MatrixTimesScalarOp::verify() {
  // NOTE(review): the `.Case<...>(` line listing the accepted matrix-like
  // types is not visible in this listing. Any handled type yields its element
  // type; anything else yields null.
1632  Type elementType =
1633  llvm::TypeSwitch<Type, Type>(getMatrix().getType())
1635  [](auto matrixType) { return matrixType.getElementType(); })
1636  .Default([](Type) { return nullptr; });
1637 
  // Unhandled matrix types are treated as a programming error, not a
  // user-facing diagnostic; presumably the op's ODS constraints exclude them.
1638  assert(elementType && "Unhandled type");
1639 
1640  // Check that the scalar type is the same as the matrix element type.
1641  if (getScalar().getType() != elementType)
1642  return emitOpError("input matrix components' type and scaling value must "
1643  "have the same type");
1644 
1645  return success();
1646 }
1647 
1648 //===----------------------------------------------------------------------===//
1649 // spirv.Transpose
1650 //===----------------------------------------------------------------------===//
1651 
1652 LogicalResult spirv::TransposeOp::verify() {
1653  auto inputMatrix = llvm::cast<spirv::MatrixType>(getMatrix().getType());
1654  auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().getType());
1655 
1656  // Verify that the input and output matrices have correct shapes.
1657  if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
1658  return emitError("input matrix rows count must be equal to "
1659  "output matrix columns count");
1660 
1661  if (inputMatrix.getNumColumns() != resultMatrix.getNumRows())
1662  return emitError("input matrix columns count must be equal to "
1663  "output matrix rows count");
1664 
1665  // Verify that the input and output matrices have the same component type
1666  if (inputMatrix.getElementType() != resultMatrix.getElementType())
1667  return emitError("input and output matrices must have the same "
1668  "component type");
1669 
1670  return success();
1671 }
1672 
1673 //===----------------------------------------------------------------------===//
1674 // spirv.MatrixTimesVector
1675 //===----------------------------------------------------------------------===//
1676 
1677 LogicalResult spirv::MatrixTimesVectorOp::verify() {
1678  auto matrixType = llvm::cast<spirv::MatrixType>(getMatrix().getType());
1679  auto vectorType = llvm::cast<VectorType>(getVector().getType());
1680  auto resultType = llvm::cast<VectorType>(getType());
1681 
1682  if (matrixType.getNumColumns() != vectorType.getNumElements())
1683  return emitOpError("matrix columns (")
1684  << matrixType.getNumColumns() << ") must match vector operand size ("
1685  << vectorType.getNumElements() << ")";
1686 
1687  if (resultType.getNumElements() != matrixType.getNumRows())
1688  return emitOpError("result size (")
1689  << resultType.getNumElements() << ") must match the matrix rows ("
1690  << matrixType.getNumRows() << ")";
1691 
1692  if (matrixType.getElementType() != resultType.getElementType())
1693  return emitOpError("matrix and result element types must match");
1694 
1695  return success();
1696 }
1697 
1698 //===----------------------------------------------------------------------===//
1699 // spirv.VectorTimesMatrix
1700 //===----------------------------------------------------------------------===//
1701 
1702 LogicalResult spirv::VectorTimesMatrixOp::verify() {
1703  auto vectorType = llvm::cast<VectorType>(getVector().getType());
1704  auto matrixType = llvm::cast<spirv::MatrixType>(getMatrix().getType());
1705  auto resultType = llvm::cast<VectorType>(getType());
1706 
1707  if (matrixType.getNumRows() != vectorType.getNumElements())
1708  return emitOpError("number of components in vector must equal the number "
1709  "of components in each column in matrix");
1710 
1711  if (resultType.getNumElements() != matrixType.getNumColumns())
1712  return emitOpError("number of columns in matrix must equal the number of "
1713  "components in result");
1714 
1715  if (matrixType.getElementType() != resultType.getElementType())
1716  return emitOpError("matrix must be a matrix with the same component type "
1717  "as the component type in result");
1718 
1719  return success();
1720 }
1721 
1722 //===----------------------------------------------------------------------===//
1723 // spirv.MatrixTimesMatrix
1724 //===----------------------------------------------------------------------===//
1725 
1726 LogicalResult spirv::MatrixTimesMatrixOp::verify() {
1727  auto leftMatrix = llvm::cast<spirv::MatrixType>(getLeftmatrix().getType());
1728  auto rightMatrix = llvm::cast<spirv::MatrixType>(getRightmatrix().getType());
1729  auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().getType());
1730 
1731  // left matrix columns' count and right matrix rows' count must be equal
1732  if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
1733  return emitError("left matrix columns' count must be equal to "
1734  "the right matrix rows' count");
1735 
1736  // right and result matrices columns' count must be the same
1737  if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns())
1738  return emitError(
1739  "right and result matrices must have equal columns' count");
1740 
1741  // right and result matrices component type must be the same
1742  if (rightMatrix.getElementType() != resultMatrix.getElementType())
1743  return emitError("right and result matrices' component type must"
1744  " be the same");
1745 
1746  // left and result matrices component type must be the same
1747  if (leftMatrix.getElementType() != resultMatrix.getElementType())
1748  return emitError("left and result matrices' component type"
1749  " must be the same");
1750 
1751  // left and result matrices rows count must be the same
1752  if (leftMatrix.getNumRows() != resultMatrix.getNumRows())
1753  return emitError("left and result matrices must have equal rows' count");
1754 
1755  return success();
1756 }
1757 
1758 //===----------------------------------------------------------------------===//
1759 // spirv.SpecConstantComposite
1760 //===----------------------------------------------------------------------===//
1761 
1763  OperationState &result) {
1764 
1765  StringAttr compositeName;
1766  if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
1767  result.attributes))
1768  return failure();
1769 
1770  if (parser.parseLParen())
1771  return failure();
1772 
1773  SmallVector<Attribute, 4> constituents;
1774 
1775  do {
1776  // The name of the constituent attribute isn't important
1777  const char *attrName = "spec_const";
1778  FlatSymbolRefAttr specConstRef;
1779  NamedAttrList attrs;
1780 
1781  if (parser.parseAttribute(specConstRef, Type(), attrName, attrs))
1782  return failure();
1783 
1784  constituents.push_back(specConstRef);
1785  } while (!parser.parseOptionalComma());
1786 
1787  if (parser.parseRParen())
1788  return failure();
1789 
1790  StringAttr compositeSpecConstituentsName =
1791  spirv::SpecConstantCompositeOp::getConstituentsAttrName(result.name);
1792  result.addAttribute(compositeSpecConstituentsName,
1793  parser.getBuilder().getArrayAttr(constituents));
1794 
1795  Type type;
1796  if (parser.parseColonType(type))
1797  return failure();
1798 
1799  StringAttr typeAttrName =
1800  spirv::SpecConstantCompositeOp::getTypeAttrName(result.name);
1801  result.addAttribute(typeAttrName, TypeAttr::get(type));
1802 
1803  return success();
1804 }
1805 
1807  printer << " ";
1808  printer.printSymbolName(getSymName());
1809  printer << " (";
1810  auto constituents = this->getConstituents().getValue();
1811 
1812  if (!constituents.empty())
1813  llvm::interleaveComma(constituents, printer);
1814 
1815  printer << ") : " << getType();
1816 }
1817 
1819  auto cType = llvm::dyn_cast<spirv::CompositeType>(getType());
1820  auto constituents = this->getConstituents().getValue();
1821 
1822  if (!cType)
1823  return emitError("result type must be a composite type, but provided ")
1824  << getType();
1825 
1826  if (llvm::isa<spirv::CooperativeMatrixType>(cType))
1827  return emitError("unsupported composite type ") << cType;
1828  if (constituents.size() != cType.getNumElements())
1829  return emitError("has incorrect number of operands: expected ")
1830  << cType.getNumElements() << ", but provided "
1831  << constituents.size();
1832 
1833  for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
1834  auto constituent = llvm::cast<FlatSymbolRefAttr>(constituents[index]);
1835 
1836  auto constituentSpecConstOp =
1837  dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom(
1838  (*this)->getParentOp(), constituent.getAttr()));
1839 
1840  if (constituentSpecConstOp.getDefaultValue().getType() !=
1841  cType.getElementType(index))
1842  return emitError("has incorrect types of operands: expected ")
1843  << cType.getElementType(index) << ", but provided "
1844  << constituentSpecConstOp.getDefaultValue().getType();
1845  }
1846 
1847  return success();
1848 }
1849 
1850 //===----------------------------------------------------------------------===//
1851 // spirv.SpecConstantOperation
1852 //===----------------------------------------------------------------------===//
1853 
1855  OperationState &result) {
1856  Region *body = result.addRegion();
1857 
1858  if (parser.parseKeyword("wraps"))
1859  return failure();
1860 
1861  body->push_back(new Block);
1862  Block &block = body->back();
1863  Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
1864 
1865  if (!wrappedOp)
1866  return failure();
1867 
1868  OpBuilder builder(parser.getContext());
1869  builder.setInsertionPointToEnd(&block);
1870  builder.create<spirv::YieldOp>(wrappedOp->getLoc(), wrappedOp->getResult(0));
1871  result.location = wrappedOp->getLoc();
1872 
1873  result.addTypes(wrappedOp->getResult(0).getType());
1874 
1875  if (parser.parseOptionalAttrDict(result.attributes))
1876  return failure();
1877 
1878  return success();
1879 }
1880 
1882  printer << " wraps ";
1883  printer.printGenericOp(&getBody().front().front());
1884 }
1885 
1886 LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
1887  Block &block = getRegion().getBlocks().front();
1888 
1889  if (block.getOperations().size() != 2)
1890  return emitOpError("expected exactly 2 nested ops");
1891 
1892  Operation &enclosedOp = block.getOperations().front();
1893 
1895  return emitOpError("invalid enclosed op");
1896 
1897  for (auto operand : enclosedOp.getOperands())
1898  if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
1899  spirv::SpecConstantOperationOp>(operand.getDefiningOp()))
1900  return emitOpError(
1901  "invalid operand, must be defined by a constant operation");
1902 
1903  return success();
1904 }
1905 
1906 //===----------------------------------------------------------------------===//
1907 // spirv.GL.FrexpStruct
1908 //===----------------------------------------------------------------------===//
1909 
1910 LogicalResult spirv::GLFrexpStructOp::verify() {
1911  spirv::StructType structTy =
1912  llvm::dyn_cast<spirv::StructType>(getResult().getType());
1913 
1914  if (structTy.getNumElements() != 2)
1915  return emitError("result type must be a struct type with two memebers");
1916 
1917  Type significandTy = structTy.getElementType(0);
1918  Type exponentTy = structTy.getElementType(1);
1919  VectorType exponentVecTy = llvm::dyn_cast<VectorType>(exponentTy);
1920  IntegerType exponentIntTy = llvm::dyn_cast<IntegerType>(exponentTy);
1921 
1922  Type operandTy = getOperand().getType();
1923  VectorType operandVecTy = llvm::dyn_cast<VectorType>(operandTy);
1924  FloatType operandFTy = llvm::dyn_cast<FloatType>(operandTy);
1925 
1926  if (significandTy != operandTy)
1927  return emitError("member zero of the resulting struct type must be the "
1928  "same type as the operand");
1929 
1930  if (exponentVecTy) {
1931  IntegerType componentIntTy =
1932  llvm::dyn_cast<IntegerType>(exponentVecTy.getElementType());
1933  if (!componentIntTy || componentIntTy.getWidth() != 32)
1934  return emitError("member one of the resulting struct type must"
1935  "be a scalar or vector of 32 bit integer type");
1936  } else if (!exponentIntTy || exponentIntTy.getWidth() != 32) {
1937  return emitError("member one of the resulting struct type "
1938  "must be a scalar or vector of 32 bit integer type");
1939  }
1940 
1941  // Check that the two member types have the same number of components
1942  if (operandVecTy && exponentVecTy &&
1943  (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
1944  return success();
1945 
1946  if (operandFTy && exponentIntTy)
1947  return success();
1948 
1949  return emitError("member one of the resulting struct type must have the same "
1950  "number of components as the operand type");
1951 }
1952 
1953 //===----------------------------------------------------------------------===//
1954 // spirv.GL.Ldexp
1955 //===----------------------------------------------------------------------===//
1956 
1957 LogicalResult spirv::GLLdexpOp::verify() {
1958  Type significandType = getX().getType();
1959  Type exponentType = getExp().getType();
1960 
1961  if (llvm::isa<FloatType>(significandType) !=
1962  llvm::isa<IntegerType>(exponentType))
1963  return emitOpError("operands must both be scalars or vectors");
1964 
1965  auto getNumElements = [](Type type) -> unsigned {
1966  if (auto vectorType = llvm::dyn_cast<VectorType>(type))
1967  return vectorType.getNumElements();
1968  return 1;
1969  };
1970 
1971  if (getNumElements(significandType) != getNumElements(exponentType))
1972  return emitOpError("operands must have the same number of elements");
1973 
1974  return success();
1975 }
1976 
1977 //===----------------------------------------------------------------------===//
1978 // spirv.ShiftLeftLogicalOp
1979 //===----------------------------------------------------------------------===//
1980 
1981 LogicalResult spirv::ShiftLeftLogicalOp::verify() {
1982  return verifyShiftOp(*this);
1983 }
1984 
1985 //===----------------------------------------------------------------------===//
1986 // spirv.ShiftRightArithmeticOp
1987 //===----------------------------------------------------------------------===//
1988 
1989 LogicalResult spirv::ShiftRightArithmeticOp::verify() {
1990  return verifyShiftOp(*this);
1991 }
1992 
1993 //===----------------------------------------------------------------------===//
1994 // spirv.ShiftRightLogicalOp
1995 //===----------------------------------------------------------------------===//
1996 
1997 LogicalResult spirv::ShiftRightLogicalOp::verify() {
1998  return verifyShiftOp(*this);
1999 }
2000 
2001 //===----------------------------------------------------------------------===//
2002 // spirv.VectorTimesScalarOp
2003 //===----------------------------------------------------------------------===//
2004 
2005 LogicalResult spirv::VectorTimesScalarOp::verify() {
2006  if (getVector().getType() != getType())
2007  return emitOpError("vector operand and result type mismatch");
2008  auto scalarType = llvm::cast<VectorType>(getType()).getElementType();
2009  if (getScalar().getType() != scalarType)
2010  return emitOpError("scalar operand and result element type match");
2011  return success();
2012 }
static std::string bindingName()
Returns the string name of the Binding decoration.
static std::string descriptorSetName()
Returns the string name of the DescriptorSet decoration.
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser, OperationState &result)
Definition: SPIRVOps.cpp:268
static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, Type opType)
Definition: SPIRVOps.cpp:558
static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, OperationState &result)
Definition: SPIRVOps.cpp:121
static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op)
Definition: SPIRVOps.cpp:254
static LogicalResult verifyShiftOp(Operation *op)
Definition: SPIRVOps.cpp:300
static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op, Value ptr, Value val)
Definition: SPIRVOps.cpp:170
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:187
static void printOneResultOp(Operation *op, OpAsmPrinter &p)
Definition: SPIRVOps.cpp:150
static void printArithmeticExtendedBinaryOp(Operation *op, OpAsmPrinter &printer)
Definition: SPIRVOps.cpp:292
const float * table
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseOptionalSymbolName(StringAttr &result)=0
Parse an optional @-identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
ParseResult addTypesToList(ArrayRef< Type > types, SmallVectorImpl< Type > &result)
Add the specified types to the end of the specified type list and return success.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType & getOperations()
Definition: Block.h:137
Operation & front()
Definition: Block.h:153
iterator begin()
Definition: Block.h:143
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:196
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:224
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:272
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:250
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:76
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:67
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:96
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:258
MLIRContext * getContext() const
Definition: Builders.h:56
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:262
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition: Builders.h:96
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual Operation * parseGenericOperation(Block *insertBlock, Block::iterator insertPt)=0
Parse an operation in its generic form.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printGenericOp(Operation *op, bool printOpName=true)=0
Print the entire operation with the default generic assembly form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:434
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:426
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
A trait to mark ops that can be enclosed/wrapped in a SpecConstantOperation op.
Definition: SPIRVOpTraits.h:33
type_range getType() const
Definition: ValueRange.cpp:30
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:750
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:550
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
void push_back(Block *block)
Definition: Region.h:61
bool empty()
Definition: Region.h:60
Block & back()
Definition: Region.h:64
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:123
Type front()
Return first type in the range.
Definition: TypeRange.h:149
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:33
static WalkResult advance()
Definition: Visitors.h:51
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:406
SPIR-V struct type.
Definition: SPIRVTypes.h:293
unsigned getNumElements() const
Type getElementType(unsigned) const
An attribute that specifies the SPIR-V (version, capabilities, extensions) triple.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:136
ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the argument at 'index'.
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
constexpr char kFnNameAttrName[]
constexpr char kSpecIdAttrName[]
LogicalResult verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics)
Definition: SPIRVOps.cpp:70
ParseResult parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next string attribute in parser as an enumerant of the given EnumClass.
void printVariableDecorations(Operation *op, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs)
Definition: SPIRVOps.cpp:94
AddressingModel getAddressingModel(TargetEnvAttr targetAttr, bool use64bitAddress)
Returns addressing model selected based on target environment.
FailureOr< ExecutionModel > getExecutionModel(TargetEnvAttr targetAttr)
Returns execution model selected based on target environment.
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
LogicalResult extractValueFromConstOp(Operation *op, int32_t &value)
Definition: SPIRVOps.cpp:50
ParseResult parseVariableDecorations(OpAsmParser &parser, OperationState &state)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:425
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.