MLIR  19.0.0git
SPIRVOps.cpp
Go to the documentation of this file.
1 //===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the operations in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
15 #include "SPIRVOpUtils.h"
16 #include "SPIRVParsingUtils.h"
17 
24 #include "mlir/IR/Builders.h"
25 #include "mlir/IR/BuiltinTypes.h"
26 #include "mlir/IR/Matchers.h"
27 #include "mlir/IR/OpDefinition.h"
29 #include "mlir/IR/Operation.h"
30 #include "mlir/IR/TypeUtilities.h"
33 #include "llvm/ADT/APFloat.h"
34 #include "llvm/ADT/APInt.h"
35 #include "llvm/ADT/ArrayRef.h"
36 #include "llvm/ADT/STLExtras.h"
37 #include "llvm/ADT/StringExtras.h"
38 #include "llvm/ADT/TypeSwitch.h"
39 #include <cassert>
40 #include <numeric>
41 #include <optional>
42 #include <type_traits>
43 
44 using namespace mlir;
45 using namespace mlir::spirv::AttrNames;
46 
47 //===----------------------------------------------------------------------===//
48 // Common utility functions
49 //===----------------------------------------------------------------------===//
50 
52  auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
53  if (!constOp) {
54  return failure();
55  }
56  auto valueAttr = constOp.getValue();
57  auto integerValueAttr = llvm::dyn_cast<IntegerAttr>(valueAttr);
58  if (!integerValueAttr) {
59  return failure();
60  }
61 
62  if (integerValueAttr.getType().isSignlessInteger())
63  value = integerValueAttr.getInt();
64  else
65  value = integerValueAttr.getSInt();
66 
67  return success();
68 }
69 
72  spirv::MemorySemantics memorySemantics) {
73  // According to the SPIR-V specification:
74  // "Despite being a mask and allowing multiple bits to be combined, it is
75  // invalid for more than one of these four bits to be set: Acquire, Release,
76  // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and
77  // Release semantics is done by setting the AcquireRelease bit, not by setting
78  // two bits."
79  auto atMostOneInSet = spirv::MemorySemantics::Acquire |
80  spirv::MemorySemantics::Release |
81  spirv::MemorySemantics::AcquireRelease |
82  spirv::MemorySemantics::SequentiallyConsistent;
83 
84  auto bitCount =
85  llvm::popcount(static_cast<uint32_t>(memorySemantics & atMostOneInSet));
86  if (bitCount > 1) {
87  return op->emitError(
88  "expected at most one of these four memory constraints "
89  "to be set: `Acquire`, `Release`,"
90  "`AcquireRelease` or `SequentiallyConsistent`");
91  }
92  return success();
93 }
94 
96  SmallVectorImpl<StringRef> &elidedAttrs) {
97  // Print optional descriptor binding
98  auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
99  stringifyDecoration(spirv::Decoration::DescriptorSet));
100  auto bindingName = llvm::convertToSnakeFromCamelCase(
101  stringifyDecoration(spirv::Decoration::Binding));
102  auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
103  auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
104  if (descriptorSet && binding) {
105  elidedAttrs.push_back(descriptorSetName);
106  elidedAttrs.push_back(bindingName);
107  printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
108  << ")";
109  }
110 
111  // Print BuiltIn attribute if present
112  auto builtInName = llvm::convertToSnakeFromCamelCase(
113  stringifyDecoration(spirv::Decoration::BuiltIn));
114  if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
115  printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
116  elidedAttrs.push_back(builtInName);
117  }
118 
119  printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
120 }
121 
123  OperationState &result) {
125  Type type;
126  // If the operand list is in-between parentheses, then we have a generic form.
127  // (see the fallback in `printOneResultOp`).
128  SMLoc loc = parser.getCurrentLocation();
129  if (!parser.parseOptionalLParen()) {
130  if (parser.parseOperandList(ops) || parser.parseRParen() ||
131  parser.parseOptionalAttrDict(result.attributes) ||
132  parser.parseColon() || parser.parseType(type))
133  return failure();
134  auto fnType = llvm::dyn_cast<FunctionType>(type);
135  if (!fnType) {
136  parser.emitError(loc, "expected function type");
137  return failure();
138  }
139  if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands))
140  return failure();
141  result.addTypes(fnType.getResults());
142  return success();
143  }
144  return failure(parser.parseOperandList(ops) ||
145  parser.parseOptionalAttrDict(result.attributes) ||
146  parser.parseColonType(type) ||
147  parser.resolveOperands(ops, type, result.operands) ||
148  parser.addTypeToList(type, result.types));
149 }
150 
152  assert(op->getNumResults() == 1 && "op should have one result");
153 
154  // If not all the operand and result types are the same, just use the
155  // generic assembly form to avoid omitting information in printing.
156  auto resultType = op->getResult(0).getType();
157  if (llvm::any_of(op->getOperandTypes(),
158  [&](Type type) { return type != resultType; })) {
159  p.printGenericOp(op, /*printOpName=*/false);
160  return;
161  }
162 
163  p << ' ';
164  p.printOperands(op->getOperands());
166  // Now we can output only one type for all operands and the result.
167  p << " : " << resultType;
168 }
169 
170 template <typename Op>
172  spirv::ImageOperandsAttr attr,
173  Operation::operand_range operands) {
174  if (!attr) {
175  if (operands.empty())
176  return success();
177 
178  return imageOp.emitError("the Image Operands should encode what operands "
179  "follow, as per Image Operands");
180  }
181 
182  // TODO: Add the validation rules for the following Image Operands.
183  spirv::ImageOperands noSupportOperands =
184  spirv::ImageOperands::Bias | spirv::ImageOperands::Lod |
185  spirv::ImageOperands::Grad | spirv::ImageOperands::ConstOffset |
186  spirv::ImageOperands::Offset | spirv::ImageOperands::ConstOffsets |
187  spirv::ImageOperands::Sample | spirv::ImageOperands::MinLod |
188  spirv::ImageOperands::MakeTexelAvailable |
189  spirv::ImageOperands::MakeTexelVisible |
190  spirv::ImageOperands::SignExtend | spirv::ImageOperands::ZeroExtend;
191 
192  if (spirv::bitEnumContainsAll(attr.getValue(), noSupportOperands))
193  llvm_unreachable("unimplemented operands of Image Operands");
194 
195  return success();
196 }
197 
198 template <typename BlockReadWriteOpTy>
200  Value ptr, Value val) {
201  auto valType = val.getType();
202  if (auto valVecTy = llvm::dyn_cast<VectorType>(valType))
203  valType = valVecTy.getElementType();
204 
205  if (valType !=
206  llvm::cast<spirv::PointerType>(ptr.getType()).getPointeeType()) {
207  return op.emitOpError("mismatch in result type and pointer type");
208  }
209  return success();
210 }
211 
212 /// Walks the given type hierarchy with the given indices, potentially down
213 /// to component granularity, to select an element type. Returns a null type
214 /// and reports errors through `emitErrorFn` on failure.
215 static Type
217  function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
218  if (indices.empty()) {
219  emitErrorFn("expected at least one index for spirv.CompositeExtract");
220  return nullptr;
221  }
222 
223  for (auto index : indices) {
224  if (auto cType = llvm::dyn_cast<spirv::CompositeType>(type)) {
225  if (cType.hasCompileTimeKnownNumElements() &&
226  (index < 0 ||
227  static_cast<uint64_t>(index) >= cType.getNumElements())) {
228  emitErrorFn("index ") << index << " out of bounds for " << type;
229  return nullptr;
230  }
231  type = cType.getElementType(index);
232  } else {
233  emitErrorFn("cannot extract from non-composite type ")
234  << type << " with index " << index;
235  return nullptr;
236  }
237  }
238  return type;
239 }
240 
241 static Type
243  function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
244  auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(indices);
245  if (!indicesArrayAttr) {
246  emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
247  return nullptr;
248  }
249  if (indicesArrayAttr.empty()) {
250  emitErrorFn("expected at least one index for spirv.CompositeExtract");
251  return nullptr;
252  }
253 
254  SmallVector<int32_t, 2> indexVals;
255  for (auto indexAttr : indicesArrayAttr) {
256  auto indexIntAttr = llvm::dyn_cast<IntegerAttr>(indexAttr);
257  if (!indexIntAttr) {
258  emitErrorFn("expected an 32-bit integer for index, but found '")
259  << indexAttr << "'";
260  return nullptr;
261  }
262  indexVals.push_back(indexIntAttr.getInt());
263  }
264  return getElementType(type, indexVals, emitErrorFn);
265 }
266 
267 static Type getElementType(Type type, Attribute indices, Location loc) {
268  auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
269  return ::mlir::emitError(loc, err);
270  };
271  return getElementType(type, indices, errorFn);
272 }
273 
274 static Type getElementType(Type type, Attribute indices, OpAsmParser &parser,
275  SMLoc loc) {
276  auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
277  return parser.emitError(loc, err);
278  };
279  return getElementType(type, indices, errorFn);
280 }
281 
282 template <typename ExtendedBinaryOp>
283 static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op) {
284  auto resultType = llvm::cast<spirv::StructType>(op.getType());
285  if (resultType.getNumElements() != 2)
286  return op.emitOpError("expected result struct type containing two members");
287 
288  if (!llvm::all_equal({op.getOperand1().getType(), op.getOperand2().getType(),
289  resultType.getElementType(0),
290  resultType.getElementType(1)}))
291  return op.emitOpError(
292  "expected all operand types and struct member types are the same");
293 
294  return success();
295 }
296 
298  OperationState &result) {
300  if (parser.parseOptionalAttrDict(result.attributes) ||
301  parser.parseOperandList(operands) || parser.parseColon())
302  return failure();
303 
304  Type resultType;
305  SMLoc loc = parser.getCurrentLocation();
306  if (parser.parseType(resultType))
307  return failure();
308 
309  auto structType = llvm::dyn_cast<spirv::StructType>(resultType);
310  if (!structType || structType.getNumElements() != 2)
311  return parser.emitError(loc, "expected spirv.struct type with two members");
312 
313  SmallVector<Type, 2> operandTypes(2, structType.getElementType(0));
314  if (parser.resolveOperands(operands, operandTypes, loc, result.operands))
315  return failure();
316 
317  result.addTypes(resultType);
318  return success();
319 }
320 
322  OpAsmPrinter &printer) {
323  printer << ' ';
324  printer.printOptionalAttrDict(op->getAttrs());
325  printer.printOperands(op->getOperands());
326  printer << " : " << op->getResultTypes().front();
327 }
328 
330  if (op->getOperand(0).getType() != op->getResult(0).getType()) {
331  return op->emitError("expected the same type for the first operand and "
332  "result, but provided ")
333  << op->getOperand(0).getType() << " and "
334  << op->getResult(0).getType();
335  }
336  return success();
337 }
338 
339 //===----------------------------------------------------------------------===//
340 // spirv.mlir.addressof
341 //===----------------------------------------------------------------------===//
342 
343 void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state,
344  spirv::GlobalVariableOp var) {
345  build(builder, state, var.getType(), SymbolRefAttr::get(var));
346 }
347 
349  auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
350  SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(),
351  getVariableAttr()));
352  if (!varOp) {
353  return emitOpError("expected spirv.GlobalVariable symbol");
354  }
355  if (getPointer().getType() != varOp.getType()) {
356  return emitOpError(
357  "result type mismatch with the referenced global variable's type");
358  }
359  return success();
360 }
361 
362 //===----------------------------------------------------------------------===//
363 // spirv.CompositeConstruct
364 //===----------------------------------------------------------------------===//
365 
367  operand_range constituents = this->getConstituents();
368 
369  // There are 4 cases with varying verification rules:
370  // 1. Cooperative Matrices (1 constituent)
371  // 2. Structs (1 constituent for each member)
372  // 3. Arrays (1 constituent for each array element)
373  // 4. Vectors (1 constituent (sub-)element for each vector element)
374 
375  auto coopElementType =
378  [](auto coopType) { return coopType.getElementType(); })
379  .Default([](Type) { return nullptr; });
380 
381  // Case 1. -- matrices.
382  if (coopElementType) {
383  if (constituents.size() != 1)
384  return emitOpError("has incorrect number of operands: expected ")
385  << "1, but provided " << constituents.size();
386  if (coopElementType != constituents.front().getType())
387  return emitOpError("operand type mismatch: expected operand type ")
388  << coopElementType << ", but provided "
389  << constituents.front().getType();
390  return success();
391  }
392 
393  // Case 2./3./4. -- number of constituents matches the number of elements.
394  auto cType = llvm::cast<spirv::CompositeType>(getType());
395  if (constituents.size() == cType.getNumElements()) {
396  for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
397  if (constituents[index].getType() != cType.getElementType(index)) {
398  return emitOpError("operand type mismatch: expected operand type ")
399  << cType.getElementType(index) << ", but provided "
400  << constituents[index].getType();
401  }
402  }
403  return success();
404  }
405 
406  // Case 4. -- check that all constituents add up to the expected vector type.
407  auto resultType = llvm::dyn_cast<VectorType>(cType);
408  if (!resultType)
409  return emitOpError(
410  "expected to return a vector or cooperative matrix when the number of "
411  "constituents is less than what the result needs");
412 
413  SmallVector<unsigned> sizes;
414  for (Value component : constituents) {
415  if (!llvm::isa<VectorType>(component.getType()) &&
416  !component.getType().isIntOrFloat())
417  return emitOpError("operand type mismatch: expected operand to have "
418  "a scalar or vector type, but provided ")
419  << component.getType();
420 
421  Type elementType = component.getType();
422  if (auto vectorType = llvm::dyn_cast<VectorType>(component.getType())) {
423  sizes.push_back(vectorType.getNumElements());
424  elementType = vectorType.getElementType();
425  } else {
426  sizes.push_back(1);
427  }
428 
429  if (elementType != resultType.getElementType())
430  return emitOpError("operand element type mismatch: expected to be ")
431  << resultType.getElementType() << ", but provided " << elementType;
432  }
433  unsigned totalCount = std::accumulate(sizes.begin(), sizes.end(), 0);
434  if (totalCount != cType.getNumElements())
435  return emitOpError("has incorrect number of operands: expected ")
436  << cType.getNumElements() << ", but provided " << totalCount;
437  return success();
438 }
439 
440 //===----------------------------------------------------------------------===//
441 // spirv.CompositeExtractOp
442 //===----------------------------------------------------------------------===//
443 
444 void spirv::CompositeExtractOp::build(OpBuilder &builder, OperationState &state,
445  Value composite,
446  ArrayRef<int32_t> indices) {
447  auto indexAttr = builder.getI32ArrayAttr(indices);
448  auto elementType =
449  getElementType(composite.getType(), indexAttr, state.location);
450  if (!elementType) {
451  return;
452  }
453  build(builder, state, elementType, composite, indexAttr);
454 }
455 
457  OperationState &result) {
458  OpAsmParser::UnresolvedOperand compositeInfo;
459  Attribute indicesAttr;
460  StringRef indicesAttrName =
461  spirv::CompositeExtractOp::getIndicesAttrName(result.name);
462  Type compositeType;
463  SMLoc attrLocation;
464 
465  if (parser.parseOperand(compositeInfo) ||
466  parser.getCurrentLocation(&attrLocation) ||
467  parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
468  parser.parseColonType(compositeType) ||
469  parser.resolveOperand(compositeInfo, compositeType, result.operands)) {
470  return failure();
471  }
472 
473  Type resultType =
474  getElementType(compositeType, indicesAttr, parser, attrLocation);
475  if (!resultType) {
476  return failure();
477  }
478  result.addTypes(resultType);
479  return success();
480 }
481 
483  printer << ' ' << getComposite() << getIndices() << " : "
484  << getComposite().getType();
485 }
486 
488  auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(getIndices());
489  auto resultType =
490  getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
491  if (!resultType)
492  return failure();
493 
494  if (resultType != getType()) {
495  return emitOpError("invalid result type: expected ")
496  << resultType << " but provided " << getType();
497  }
498 
499  return success();
500 }
501 
502 //===----------------------------------------------------------------------===//
503 // spirv.CompositeInsert
504 //===----------------------------------------------------------------------===//
505 
506 void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state,
507  Value object, Value composite,
508  ArrayRef<int32_t> indices) {
509  auto indexAttr = builder.getI32ArrayAttr(indices);
510  build(builder, state, composite.getType(), object, composite, indexAttr);
511 }
512 
514  OperationState &result) {
516  Type objectType, compositeType;
517  Attribute indicesAttr;
518  StringRef indicesAttrName =
519  spirv::CompositeInsertOp::getIndicesAttrName(result.name);
520  auto loc = parser.getCurrentLocation();
521 
522  return failure(
523  parser.parseOperandList(operands, 2) ||
524  parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
525  parser.parseColonType(objectType) ||
526  parser.parseKeywordType("into", compositeType) ||
527  parser.resolveOperands(operands, {objectType, compositeType}, loc,
528  result.operands) ||
529  parser.addTypesToList(compositeType, result.types));
530 }
531 
533  auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(getIndices());
534  auto objectType =
535  getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
536  if (!objectType)
537  return failure();
538 
539  if (objectType != getObject().getType()) {
540  return emitOpError("object operand type should be ")
541  << objectType << ", but found " << getObject().getType();
542  }
543 
544  if (getComposite().getType() != getType()) {
545  return emitOpError("result type should be the same as "
546  "the composite type, but found ")
547  << getComposite().getType() << " vs " << getType();
548  }
549 
550  return success();
551 }
552 
554  printer << " " << getObject() << ", " << getComposite() << getIndices()
555  << " : " << getObject().getType() << " into "
556  << getComposite().getType();
557 }
558 
559 //===----------------------------------------------------------------------===//
560 // spirv.Constant
561 //===----------------------------------------------------------------------===//
562 
564  OperationState &result) {
565  Attribute value;
566  StringRef valueAttrName = spirv::ConstantOp::getValueAttrName(result.name);
567  if (parser.parseAttribute(value, valueAttrName, result.attributes))
568  return failure();
569 
570  Type type = NoneType::get(parser.getContext());
571  if (auto typedAttr = llvm::dyn_cast<TypedAttr>(value))
572  type = typedAttr.getType();
573  if (llvm::isa<NoneType, TensorType>(type)) {
574  if (parser.parseColonType(type))
575  return failure();
576  }
577 
578  return parser.addTypeToList(type, result.types);
579 }
580 
581 void spirv::ConstantOp::print(OpAsmPrinter &printer) {
582  printer << ' ' << getValue();
583  if (llvm::isa<spirv::ArrayType>(getType()))
584  printer << " : " << getType();
585 }
586 
587 static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
588  Type opType) {
589  if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
590  auto valueType = llvm::cast<TypedAttr>(value).getType();
591  if (valueType != opType)
592  return op.emitOpError("result type (")
593  << opType << ") does not match value type (" << valueType << ")";
594  return success();
595  }
596  if (llvm::isa<DenseIntOrFPElementsAttr, SparseElementsAttr>(value)) {
597  auto valueType = llvm::cast<TypedAttr>(value).getType();
598  if (valueType == opType)
599  return success();
600  auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
601  auto shapedType = llvm::dyn_cast<ShapedType>(valueType);
602  if (!arrayType)
603  return op.emitOpError("result or element type (")
604  << opType << ") does not match value type (" << valueType
605  << "), must be the same or spirv.array";
606 
607  int numElements = arrayType.getNumElements();
608  auto opElemType = arrayType.getElementType();
609  while (auto t = llvm::dyn_cast<spirv::ArrayType>(opElemType)) {
610  numElements *= t.getNumElements();
611  opElemType = t.getElementType();
612  }
613  if (!opElemType.isIntOrFloat())
614  return op.emitOpError("only support nested array result type");
615 
616  auto valueElemType = shapedType.getElementType();
617  if (valueElemType != opElemType) {
618  return op.emitOpError("result element type (")
619  << opElemType << ") does not match value element type ("
620  << valueElemType << ")";
621  }
622 
623  if (numElements != shapedType.getNumElements()) {
624  return op.emitOpError("result number of elements (")
625  << numElements << ") does not match value number of elements ("
626  << shapedType.getNumElements() << ")";
627  }
628  return success();
629  }
630  if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(value)) {
631  auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
632  if (!arrayType)
633  return op.emitOpError(
634  "must have spirv.array result type for array value");
635  Type elemType = arrayType.getElementType();
636  for (Attribute element : arrayAttr.getValue()) {
637  // Verify array elements recursively.
638  if (failed(verifyConstantType(op, element, elemType)))
639  return failure();
640  }
641  return success();
642  }
643  return op.emitOpError("cannot have attribute: ") << value;
644 }
645 
647  // ODS already generates checks to make sure the result type is valid. We just
648  // need to additionally check that the value's attribute type is consistent
649  // with the result type.
650  return verifyConstantType(*this, getValueAttr(), getType());
651 }
652 
653 bool spirv::ConstantOp::isBuildableWith(Type type) {
654  // Must be valid SPIR-V type first.
655  if (!llvm::isa<spirv::SPIRVType>(type))
656  return false;
657 
658  if (isa<SPIRVDialect>(type.getDialect())) {
659  // TODO: support constant struct
660  return llvm::isa<spirv::ArrayType>(type);
661  }
662 
663  return true;
664 }
665 
666 spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
667  OpBuilder &builder) {
668  if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
669  unsigned width = intType.getWidth();
670  if (width == 1)
671  return builder.create<spirv::ConstantOp>(loc, type,
672  builder.getBoolAttr(false));
673  return builder.create<spirv::ConstantOp>(
674  loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
675  }
676  if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
677  return builder.create<spirv::ConstantOp>(
678  loc, type, builder.getFloatAttr(floatType, 0.0));
679  }
680  if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
681  Type elemType = vectorType.getElementType();
682  if (llvm::isa<IntegerType>(elemType)) {
683  return builder.create<spirv::ConstantOp>(
684  loc, type,
685  DenseElementsAttr::get(vectorType,
686  IntegerAttr::get(elemType, 0).getValue()));
687  }
688  if (llvm::isa<FloatType>(elemType)) {
689  return builder.create<spirv::ConstantOp>(
690  loc, type,
691  DenseFPElementsAttr::get(vectorType,
692  FloatAttr::get(elemType, 0.0).getValue()));
693  }
694  }
695 
696  llvm_unreachable("unimplemented types for ConstantOp::getZero()");
697 }
698 
699 spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
700  OpBuilder &builder) {
701  if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
702  unsigned width = intType.getWidth();
703  if (width == 1)
704  return builder.create<spirv::ConstantOp>(loc, type,
705  builder.getBoolAttr(true));
706  return builder.create<spirv::ConstantOp>(
707  loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
708  }
709  if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
710  return builder.create<spirv::ConstantOp>(
711  loc, type, builder.getFloatAttr(floatType, 1.0));
712  }
713  if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
714  Type elemType = vectorType.getElementType();
715  if (llvm::isa<IntegerType>(elemType)) {
716  return builder.create<spirv::ConstantOp>(
717  loc, type,
718  DenseElementsAttr::get(vectorType,
719  IntegerAttr::get(elemType, 1).getValue()));
720  }
721  if (llvm::isa<FloatType>(elemType)) {
722  return builder.create<spirv::ConstantOp>(
723  loc, type,
724  DenseFPElementsAttr::get(vectorType,
725  FloatAttr::get(elemType, 1.0).getValue()));
726  }
727  }
728 
729  llvm_unreachable("unimplemented types for ConstantOp::getOne()");
730 }
731 
732 void mlir::spirv::ConstantOp::getAsmResultNames(
733  llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
734  Type type = getType();
735 
736  SmallString<32> specialNameBuffer;
737  llvm::raw_svector_ostream specialName(specialNameBuffer);
738  specialName << "cst";
739 
740  IntegerType intTy = llvm::dyn_cast<IntegerType>(type);
741 
742  if (IntegerAttr intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
743  if (intTy && intTy.getWidth() == 1) {
744  return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
745  }
746 
747  if (intTy.isSignless()) {
748  specialName << intCst.getInt();
749  } else if (intTy.isUnsigned()) {
750  specialName << intCst.getUInt();
751  } else {
752  specialName << intCst.getSInt();
753  }
754  }
755 
756  if (intTy || llvm::isa<FloatType>(type)) {
757  specialName << '_' << type;
758  }
759 
760  if (auto vecType = llvm::dyn_cast<VectorType>(type)) {
761  specialName << "_vec_";
762  specialName << vecType.getDimSize(0);
763 
764  Type elementType = vecType.getElementType();
765 
766  if (llvm::isa<IntegerType>(elementType) ||
767  llvm::isa<FloatType>(elementType)) {
768  specialName << "x" << elementType;
769  }
770  }
771 
772  setNameFn(getResult(), specialName.str());
773 }
774 
775 void mlir::spirv::AddressOfOp::getAsmResultNames(
776  llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
777  SmallString<32> specialNameBuffer;
778  llvm::raw_svector_ostream specialName(specialNameBuffer);
779  specialName << getVariable() << "_addr";
780  setNameFn(getResult(), specialName.str());
781 }
782 
783 //===----------------------------------------------------------------------===//
784 // spirv.ControlBarrierOp
785 //===----------------------------------------------------------------------===//
786 
788  return verifyMemorySemantics(getOperation(), getMemorySemantics());
789 }
790 
791 //===----------------------------------------------------------------------===//
792 // spirv.EntryPoint
793 //===----------------------------------------------------------------------===//
794 
795 void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state,
796  spirv::ExecutionModel executionModel,
797  spirv::FuncOp function,
798  ArrayRef<Attribute> interfaceVars) {
799  build(builder, state,
800  spirv::ExecutionModelAttr::get(builder.getContext(), executionModel),
801  SymbolRefAttr::get(function), builder.getArrayAttr(interfaceVars));
802 }
803 
805  OperationState &result) {
806  spirv::ExecutionModel execModel;
808  SmallVector<Type, 0> idTypes;
809  SmallVector<Attribute, 4> interfaceVars;
810 
812  if (parseEnumStrAttr<spirv::ExecutionModelAttr>(execModel, parser, result) ||
813  parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) {
814  return failure();
815  }
816 
817  if (!parser.parseOptionalComma()) {
818  // Parse the interface variables
819  if (parser.parseCommaSeparatedList([&]() -> ParseResult {
820  // The name of the interface variable attribute isn't important
821  FlatSymbolRefAttr var;
822  NamedAttrList attrs;
823  if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
824  return failure();
825  interfaceVars.push_back(var);
826  return success();
827  }))
828  return failure();
829  }
830  result.addAttribute(spirv::EntryPointOp::getInterfaceAttrName(result.name),
831  parser.getBuilder().getArrayAttr(interfaceVars));
832  return success();
833 }
834 
836  printer << " \"" << stringifyExecutionModel(getExecutionModel()) << "\" ";
837  printer.printSymbolName(getFn());
838  auto interfaceVars = getInterface().getValue();
839  if (!interfaceVars.empty()) {
840  printer << ", ";
841  llvm::interleaveComma(interfaceVars, printer);
842  }
843 }
844 
846  // Checks for fn and interface symbol reference are done in spirv::ModuleOp
847  // verification.
848  return success();
849 }
850 
851 //===----------------------------------------------------------------------===//
852 // spirv.ExecutionMode
853 //===----------------------------------------------------------------------===//
854 
855 void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state,
856  spirv::FuncOp function,
857  spirv::ExecutionMode executionMode,
858  ArrayRef<int32_t> params) {
859  build(builder, state, SymbolRefAttr::get(function),
860  spirv::ExecutionModeAttr::get(builder.getContext(), executionMode),
861  builder.getI32ArrayAttr(params));
862 }
863 
865  OperationState &result) {
866  spirv::ExecutionMode execMode;
867  Attribute fn;
868  if (parser.parseAttribute(fn, kFnNameAttrName, result.attributes) ||
869  parseEnumStrAttr<spirv::ExecutionModeAttr>(execMode, parser, result)) {
870  return failure();
871  }
872 
874  Type i32Type = parser.getBuilder().getIntegerType(32);
875  while (!parser.parseOptionalComma()) {
876  NamedAttrList attr;
877  Attribute value;
878  if (parser.parseAttribute(value, i32Type, "value", attr)) {
879  return failure();
880  }
881  values.push_back(llvm::cast<IntegerAttr>(value).getInt());
882  }
883  StringRef valuesAttrName =
884  spirv::ExecutionModeOp::getValuesAttrName(result.name);
885  result.addAttribute(valuesAttrName,
886  parser.getBuilder().getI32ArrayAttr(values));
887  return success();
888 }
889 
891  printer << " ";
892  printer.printSymbolName(getFn());
893  printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\"";
894  auto values = this->getValues();
895  if (values.empty())
896  return;
897  printer << ", ";
898  llvm::interleaveComma(values, printer, [&](Attribute a) {
899  printer << llvm::cast<IntegerAttr>(a).getInt();
900  });
901 }
902 
903 //===----------------------------------------------------------------------===//
904 // spirv.func
905 //===----------------------------------------------------------------------===//
906 
909  SmallVector<DictionaryAttr> resultAttrs;
910  SmallVector<Type> resultTypes;
911  auto &builder = parser.getBuilder();
912 
913  // Parse the name as a symbol.
914  StringAttr nameAttr;
915  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
916  result.attributes))
917  return failure();
918 
919  // Parse the function signature.
920  bool isVariadic = false;
922  parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
923  resultAttrs))
924  return failure();
925 
926  SmallVector<Type> argTypes;
927  for (auto &arg : entryArgs)
928  argTypes.push_back(arg.type);
929  auto fnType = builder.getFunctionType(argTypes, resultTypes);
930  result.addAttribute(getFunctionTypeAttrName(result.name),
931  TypeAttr::get(fnType));
932 
933  // Parse the optional function control keyword.
934  spirv::FunctionControl fnControl;
935  if (parseEnumStrAttr<spirv::FunctionControlAttr>(fnControl, parser, result))
936  return failure();
937 
938  // If additional attributes are present, parse them.
939  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
940  return failure();
941 
942  // Add the attributes to the function arguments.
943  assert(resultAttrs.size() == resultTypes.size());
945  builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
946  getResAttrsAttrName(result.name));
947 
948  // Parse the optional function body.
949  auto *body = result.addRegion();
950  OptionalParseResult parseResult =
951  parser.parseOptionalRegion(*body, entryArgs);
952  return failure(parseResult.has_value() && failed(*parseResult));
953 }
954 
955 void spirv::FuncOp::print(OpAsmPrinter &printer) {
956  // Print function name, signature, and control.
957  printer << " ";
958  printer.printSymbolName(getSymName());
959  auto fnType = getFunctionType();
961  printer, *this, fnType.getInputs(),
962  /*isVariadic=*/false, fnType.getResults());
963  printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl())
964  << "\"";
966  printer, *this,
967  {spirv::attributeName<spirv::FunctionControl>(),
968  getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
969  getFunctionControlAttrName()});
970 
971  // Print the body if this is not an external function.
972  Region &body = this->getBody();
973  if (!body.empty()) {
974  printer << ' ';
975  printer.printRegion(body, /*printEntryBlockArgs=*/false,
976  /*printBlockTerminators=*/true);
977  }
978 }
979 
980 LogicalResult spirv::FuncOp::verifyType() {
981  FunctionType fnType = getFunctionType();
982  if (fnType.getNumResults() > 1)
983  return emitOpError("cannot have more than one result");
984 
985  auto hasDecorationAttr = [&](spirv::Decoration decoration,
986  unsigned argIndex) {
987  auto func = llvm::cast<FunctionOpInterface>(getOperation());
988  for (auto argAttr : cast<FunctionOpInterface>(func).getArgAttrs(argIndex)) {
989  if (argAttr.getName() != spirv::DecorationAttr::name)
990  continue;
991  if (auto decAttr = dyn_cast<spirv::DecorationAttr>(argAttr.getValue()))
992  return decAttr.getValue() == decoration;
993  }
994  return false;
995  };
996 
997  for (unsigned i = 0, e = this->getNumArguments(); i != e; ++i) {
998  Type param = fnType.getInputs()[i];
999  auto inputPtrType = dyn_cast<spirv::PointerType>(param);
1000  if (!inputPtrType)
1001  continue;
1002 
1003  auto pointeePtrType =
1004  dyn_cast<spirv::PointerType>(inputPtrType.getPointeeType());
1005  if (pointeePtrType) {
1006  // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
1007  // > If an OpFunctionParameter is a pointer (or contains a pointer)
1008  // > and the type it points to is a pointer in the PhysicalStorageBuffer
1009  // > storage class, the function parameter must be decorated with exactly
1010  // > one of AliasedPointer or RestrictPointer.
1011  if (pointeePtrType.getStorageClass() !=
1012  spirv::StorageClass::PhysicalStorageBuffer)
1013  continue;
1014 
1015  bool hasAliasedPtr =
1016  hasDecorationAttr(spirv::Decoration::AliasedPointer, i);
1017  bool hasRestrictPtr =
1018  hasDecorationAttr(spirv::Decoration::RestrictPointer, i);
1019  if (!hasAliasedPtr && !hasRestrictPtr)
1020  return emitOpError()
1021  << "with a pointer points to a physical buffer pointer must "
1022  "be decorated either 'AliasedPointer' or 'RestrictPointer'";
1023  continue;
1024  }
1025  // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
1026  // > If an OpFunctionParameter is a pointer (or contains a pointer) in
1027  // > the PhysicalStorageBuffer storage class, the function parameter must
1028  // > be decorated with exactly one of Aliased or Restrict.
1029  if (auto pointeeArrayType =
1030  dyn_cast<spirv::ArrayType>(inputPtrType.getPointeeType())) {
1031  pointeePtrType =
1032  dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
1033  } else {
1034  pointeePtrType = inputPtrType;
1035  }
1036 
1037  if (!pointeePtrType || pointeePtrType.getStorageClass() !=
1038  spirv::StorageClass::PhysicalStorageBuffer)
1039  continue;
1040 
1041  bool hasAliased = hasDecorationAttr(spirv::Decoration::Aliased, i);
1042  bool hasRestrict = hasDecorationAttr(spirv::Decoration::Restrict, i);
1043  if (!hasAliased && !hasRestrict)
1044  return emitOpError() << "with physical buffer pointer must be decorated "
1045  "either 'Aliased' or 'Restrict'";
1046  }
1047 
1048  return success();
1049 }
1050 
1051 LogicalResult spirv::FuncOp::verifyBody() {
1052  FunctionType fnType = getFunctionType();
1053 
1054  auto walkResult = walk([fnType](Operation *op) -> WalkResult {
1055  if (auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
1056  if (fnType.getNumResults() != 0)
1057  return retOp.emitOpError("cannot be used in functions returning value");
1058  } else if (auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
1059  if (fnType.getNumResults() != 1)
1060  return retOp.emitOpError(
1061  "returns 1 value but enclosing function requires ")
1062  << fnType.getNumResults() << " results";
1063 
1064  auto retOperandType = retOp.getValue().getType();
1065  auto fnResultType = fnType.getResult(0);
1066  if (retOperandType != fnResultType)
1067  return retOp.emitOpError(" return value's type (")
1068  << retOperandType << ") mismatch with function's result type ("
1069  << fnResultType << ")";
1070  }
1071  return WalkResult::advance();
1072  });
1073 
1074  // TODO: verify other bits like linkage type.
1075 
1076  return failure(walkResult.wasInterrupted());
1077 }
1078 
1079 void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
1080  StringRef name, FunctionType type,
1081  spirv::FunctionControl control,
1082  ArrayRef<NamedAttribute> attrs) {
1083  state.addAttribute(SymbolTable::getSymbolAttrName(),
1084  builder.getStringAttr(name));
1085  state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
1086  state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
1087  builder.getAttr<spirv::FunctionControlAttr>(control));
1088  state.attributes.append(attrs.begin(), attrs.end());
1089  state.addRegion();
1090 }
1091 
1092 //===----------------------------------------------------------------------===//
1093 // spirv.GLFClampOp
1094 //===----------------------------------------------------------------------===//
1095 
1097  OperationState &result) {
1098  return parseOneResultSameOperandTypeOp(parser, result);
1099 }
1101 
1102 //===----------------------------------------------------------------------===//
1103 // spirv.GLUClampOp
1104 //===----------------------------------------------------------------------===//
1105 
1107  OperationState &result) {
1108  return parseOneResultSameOperandTypeOp(parser, result);
1109 }
1111 
1112 //===----------------------------------------------------------------------===//
1113 // spirv.GLSClampOp
1114 //===----------------------------------------------------------------------===//
1115 
1117  OperationState &result) {
1118  return parseOneResultSameOperandTypeOp(parser, result);
1119 }
1121 
1122 //===----------------------------------------------------------------------===//
1123 // spirv.GLFmaOp
1124 //===----------------------------------------------------------------------===//
1125 
1127  return parseOneResultSameOperandTypeOp(parser, result);
1128 }
1129 void spirv::GLFmaOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1130 
1131 //===----------------------------------------------------------------------===//
1132 // spirv.GlobalVariable
1133 //===----------------------------------------------------------------------===//
1134 
1135 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
1136  Type type, StringRef name,
1137  unsigned descriptorSet, unsigned binding) {
1138  build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
1139  state.addAttribute(
1140  spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
1141  builder.getI32IntegerAttr(descriptorSet));
1142  state.addAttribute(
1143  spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
1144  builder.getI32IntegerAttr(binding));
1145 }
1146 
1147 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
1148  Type type, StringRef name,
1149  spirv::BuiltIn builtin) {
1150  build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
1151  state.addAttribute(
1152  spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
1153  builder.getStringAttr(spirv::stringifyBuiltIn(builtin)));
1154 }
1155 
1157  OperationState &result) {
1158  // Parse variable name.
1159  StringAttr nameAttr;
1160  StringRef initializerAttrName =
1161  spirv::GlobalVariableOp::getInitializerAttrName(result.name);
1162  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1163  result.attributes)) {
1164  return failure();
1165  }
1166 
1167  // Parse optional initializer
1168  if (succeeded(parser.parseOptionalKeyword(initializerAttrName))) {
1169  FlatSymbolRefAttr initSymbol;
1170  if (parser.parseLParen() ||
1171  parser.parseAttribute(initSymbol, Type(), initializerAttrName,
1172  result.attributes) ||
1173  parser.parseRParen())
1174  return failure();
1175  }
1176 
1177  if (parseVariableDecorations(parser, result)) {
1178  return failure();
1179  }
1180 
1181  Type type;
1182  StringRef typeAttrName =
1183  spirv::GlobalVariableOp::getTypeAttrName(result.name);
1184  auto loc = parser.getCurrentLocation();
1185  if (parser.parseColonType(type)) {
1186  return failure();
1187  }
1188  if (!llvm::isa<spirv::PointerType>(type)) {
1189  return parser.emitError(loc, "expected spirv.ptr type");
1190  }
1191  result.addAttribute(typeAttrName, TypeAttr::get(type));
1192 
1193  return success();
1194 }
1195 
1197  SmallVector<StringRef, 4> elidedAttrs{
1198  spirv::attributeName<spirv::StorageClass>()};
1199 
1200  // Print variable name.
1201  printer << ' ';
1202  printer.printSymbolName(getSymName());
1203  elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
1204 
1205  StringRef initializerAttrName = this->getInitializerAttrName();
1206  // Print optional initializer
1207  if (auto initializer = this->getInitializer()) {
1208  printer << " " << initializerAttrName << '(';
1209  printer.printSymbolName(*initializer);
1210  printer << ')';
1211  elidedAttrs.push_back(initializerAttrName);
1212  }
1213 
1214  StringRef typeAttrName = this->getTypeAttrName();
1215  elidedAttrs.push_back(typeAttrName);
1216  spirv::printVariableDecorations(*this, printer, elidedAttrs);
1217  printer << " : " << getType();
1218 }
1219 
1221  if (!llvm::isa<spirv::PointerType>(getType()))
1222  return emitOpError("result must be of a !spv.ptr type");
1223 
1224  // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
1225  // object. It cannot be Generic. It must be the same as the Storage Class
1226  // operand of the Result Type."
1227  // Also, Function storage class is reserved by spirv.Variable.
1228  auto storageClass = this->storageClass();
1229  if (storageClass == spirv::StorageClass::Generic ||
1230  storageClass == spirv::StorageClass::Function) {
1231  return emitOpError("storage class cannot be '")
1232  << stringifyStorageClass(storageClass) << "'";
1233  }
1234 
1235  if (auto init = (*this)->getAttrOfType<FlatSymbolRefAttr>(
1236  this->getInitializerAttrName())) {
1238  (*this)->getParentOp(), init.getAttr());
1239  // TODO: Currently only variable initialization with specialization
1240  // constants and other variables is supported. They could be normal
1241  // constants in the module scope as well.
1242  if (!initOp || !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp,
1243  spirv::SpecConstantCompositeOp>(initOp)) {
1244  return emitOpError("initializer must be result of a "
1245  "spirv.SpecConstant or spirv.GlobalVariable or "
1246  "spirv.SpecConstantCompositeOp op");
1247  }
1248  }
1249 
1250  return success();
1251 }
1252 
1253 //===----------------------------------------------------------------------===//
1254 // spirv.INTEL.SubgroupBlockRead
1255 //===----------------------------------------------------------------------===//
1256 
1258  OperationState &result) {
1259  // Parse the storage class specification
1260  spirv::StorageClass storageClass;
1262  Type elementType;
1263  if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
1264  parser.parseColon() || parser.parseType(elementType)) {
1265  return failure();
1266  }
1267 
1268  auto ptrType = spirv::PointerType::get(elementType, storageClass);
1269  if (auto valVecTy = llvm::dyn_cast<VectorType>(elementType))
1270  ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
1271 
1272  if (parser.resolveOperand(ptrInfo, ptrType, result.operands)) {
1273  return failure();
1274  }
1275 
1276  result.addTypes(elementType);
1277  return success();
1278 }
1279 
1281  printer << " " << getPtr() << " : " << getType();
1282 }
1283 
1285  if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
1286  return failure();
1287 
1288  return success();
1289 }
1290 
1291 //===----------------------------------------------------------------------===//
1292 // spirv.INTEL.SubgroupBlockWrite
1293 //===----------------------------------------------------------------------===//
1294 
1296  OperationState &result) {
1297  // Parse the storage class specification
1298  spirv::StorageClass storageClass;
1300  auto loc = parser.getCurrentLocation();
1301  Type elementType;
1302  if (parseEnumStrAttr(storageClass, parser) ||
1303  parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
1304  parser.parseType(elementType)) {
1305  return failure();
1306  }
1307 
1308  auto ptrType = spirv::PointerType::get(elementType, storageClass);
1309  if (auto valVecTy = llvm::dyn_cast<VectorType>(elementType))
1310  ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
1311 
1312  if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
1313  result.operands)) {
1314  return failure();
1315  }
1316  return success();
1317 }
1318 
1320  printer << " " << getPtr() << ", " << getValue() << " : "
1321  << getValue().getType();
1322 }
1323 
1325  if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
1326  return failure();
1327 
1328  return success();
1329 }
1330 
1331 //===----------------------------------------------------------------------===//
1332 // spirv.IAddCarryOp
1333 //===----------------------------------------------------------------------===//
1334 
1337 }
1338 
1340  OperationState &result) {
1342 }
1343 
1344 void spirv::IAddCarryOp::print(OpAsmPrinter &printer) {
1345  ::printArithmeticExtendedBinaryOp(*this, printer);
1346 }
1347 
1348 //===----------------------------------------------------------------------===//
1349 // spirv.ISubBorrowOp
1350 //===----------------------------------------------------------------------===//
1351 
1354 }
1355 
1357  OperationState &result) {
1359 }
1360 
1362  ::printArithmeticExtendedBinaryOp(*this, printer);
1363 }
1364 
1365 //===----------------------------------------------------------------------===//
1366 // spirv.SMulExtended
1367 //===----------------------------------------------------------------------===//
1368 
1371 }
1372 
1374  OperationState &result) {
1376 }
1377 
1379  ::printArithmeticExtendedBinaryOp(*this, printer);
1380 }
1381 
1382 //===----------------------------------------------------------------------===//
1383 // spirv.UMulExtended
1384 //===----------------------------------------------------------------------===//
1385 
1388 }
1389 
1391  OperationState &result) {
1393 }
1394 
1396  ::printArithmeticExtendedBinaryOp(*this, printer);
1397 }
1398 
1399 //===----------------------------------------------------------------------===//
1400 // spirv.MemoryBarrierOp
1401 //===----------------------------------------------------------------------===//
1402 
1404  return verifyMemorySemantics(getOperation(), getMemorySemantics());
1405 }
1406 
1407 //===----------------------------------------------------------------------===//
1408 // spirv.module
1409 //===----------------------------------------------------------------------===//
1410 
1411 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
1412  std::optional<StringRef> name) {
1413  OpBuilder::InsertionGuard guard(builder);
1414  builder.createBlock(state.addRegion());
1415  if (name) {
1416  state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
1417  builder.getStringAttr(*name));
1418  }
1419 }
1420 
1421 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
1422  spirv::AddressingModel addressingModel,
1423  spirv::MemoryModel memoryModel,
1424  std::optional<VerCapExtAttr> vceTriple,
1425  std::optional<StringRef> name) {
1426  state.addAttribute(
1427  "addressing_model",
1428  builder.getAttr<spirv::AddressingModelAttr>(addressingModel));
1429  state.addAttribute("memory_model",
1430  builder.getAttr<spirv::MemoryModelAttr>(memoryModel));
1431  OpBuilder::InsertionGuard guard(builder);
1432  builder.createBlock(state.addRegion());
1433  if (vceTriple)
1434  state.addAttribute(getVCETripleAttrName(), *vceTriple);
1435  if (name)
1436  state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
1437  builder.getStringAttr(*name));
1438 }
1439 
1441  OperationState &result) {
1442  Region *body = result.addRegion();
1443 
1444  // If the name is present, parse it.
1445  StringAttr nameAttr;
1446  (void)parser.parseOptionalSymbolName(
1447  nameAttr, mlir::SymbolTable::getSymbolAttrName(), result.attributes);
1448 
1449  // Parse attributes
1450  spirv::AddressingModel addrModel;
1451  spirv::MemoryModel memoryModel;
1452  if (spirv::parseEnumKeywordAttr<spirv::AddressingModelAttr>(addrModel, parser,
1453  result) ||
1454  spirv::parseEnumKeywordAttr<spirv::MemoryModelAttr>(memoryModel, parser,
1455  result))
1456  return failure();
1457 
1458  if (succeeded(parser.parseOptionalKeyword("requires"))) {
1459  spirv::VerCapExtAttr vceTriple;
1460  if (parser.parseAttribute(vceTriple,
1461  spirv::ModuleOp::getVCETripleAttrName(),
1462  result.attributes))
1463  return failure();
1464  }
1465 
1466  if (parser.parseOptionalAttrDictWithKeyword(result.attributes) ||
1467  parser.parseRegion(*body, /*arguments=*/{}))
1468  return failure();
1469 
1470  // Make sure we have at least one block.
1471  if (body->empty())
1472  body->push_back(new Block());
1473 
1474  return success();
1475 }
1476 
1477 void spirv::ModuleOp::print(OpAsmPrinter &printer) {
1478  if (std::optional<StringRef> name = getName()) {
1479  printer << ' ';
1480  printer.printSymbolName(*name);
1481  }
1482 
1483  SmallVector<StringRef, 2> elidedAttrs;
1484 
1485  printer << " " << spirv::stringifyAddressingModel(getAddressingModel()) << " "
1486  << spirv::stringifyMemoryModel(getMemoryModel());
1487  auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
1488  auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
1489  elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
1491 
1492  if (std::optional<spirv::VerCapExtAttr> triple = getVceTriple()) {
1493  printer << " requires " << *triple;
1494  elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
1495  }
1496 
1497  printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs);
1498  printer << ' ';
1499  printer.printRegion(getRegion());
1500 }
1501 
1502 LogicalResult spirv::ModuleOp::verifyRegions() {
1503  Dialect *dialect = (*this)->getDialect();
1505  entryPoints;
1506  mlir::SymbolTable table(*this);
1507 
1508  for (auto &op : *getBody()) {
1509  if (op.getDialect() != dialect)
1510  return op.emitError("'spirv.module' can only contain spirv.* ops");
1511 
1512  // For EntryPoint op, check that the function and execution model is not
1513  // duplicated in EntryPointOps. Also verify that the interface specified
1514  // comes from globalVariables here to make this check cheaper.
1515  if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
1516  auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.getFn());
1517  if (!funcOp) {
1518  return entryPointOp.emitError("function '")
1519  << entryPointOp.getFn() << "' not found in 'spirv.module'";
1520  }
1521  if (auto interface = entryPointOp.getInterface()) {
1522  for (Attribute varRef : interface) {
1523  auto varSymRef = llvm::dyn_cast<FlatSymbolRefAttr>(varRef);
1524  if (!varSymRef) {
1525  return entryPointOp.emitError(
1526  "expected symbol reference for interface "
1527  "specification instead of '")
1528  << varRef;
1529  }
1530  auto variableOp =
1531  table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
1532  if (!variableOp) {
1533  return entryPointOp.emitError("expected spirv.GlobalVariable "
1534  "symbol reference instead of'")
1535  << varSymRef << "'";
1536  }
1537  }
1538  }
1539 
1540  auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
1541  funcOp, entryPointOp.getExecutionModel());
1542  auto entryPtIt = entryPoints.find(key);
1543  if (entryPtIt != entryPoints.end()) {
1544  return entryPointOp.emitError("duplicate of a previous EntryPointOp");
1545  }
1546  entryPoints[key] = entryPointOp;
1547  } else if (auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
1548  // If the function is external and does not have 'Import'
1549  // linkage_attributes(LinkageAttributes), throw an error. 'Import'
1550  // LinkageAttributes is used to import external functions.
1551  auto linkageAttr = funcOp.getLinkageAttributes();
1552  auto hasImportLinkage =
1553  linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
1554  spirv::LinkageType::Import);
1555  if (funcOp.isExternal() && !hasImportLinkage)
1556  return op.emitError(
1557  "'spirv.module' cannot contain external functions "
1558  "without 'Import' linkage_attributes (LinkageAttributes)");
1559 
1560  // TODO: move this check to spirv.func.
1561  for (auto &block : funcOp)
1562  for (auto &op : block) {
1563  if (op.getDialect() != dialect)
1564  return op.emitError(
1565  "functions in 'spirv.module' can only contain spirv.* ops");
1566  }
1567  }
1568  }
1569 
1570  return success();
1571 }
1572 
1573 //===----------------------------------------------------------------------===//
1574 // spirv.mlir.referenceof
1575 //===----------------------------------------------------------------------===//
1576 
1578  auto *specConstSym = SymbolTable::lookupNearestSymbolFrom(
1579  (*this)->getParentOp(), getSpecConstAttr());
1580  Type constType;
1581 
1582  auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
1583  if (specConstOp)
1584  constType = specConstOp.getDefaultValue().getType();
1585 
1586  auto specConstCompositeOp =
1587  dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
1588  if (specConstCompositeOp)
1589  constType = specConstCompositeOp.getType();
1590 
1591  if (!specConstOp && !specConstCompositeOp)
1592  return emitOpError(
1593  "expected spirv.SpecConstant or spirv.SpecConstantComposite symbol");
1594 
1595  if (getReference().getType() != constType)
1596  return emitOpError("result type mismatch with the referenced "
1597  "specialization constant's type");
1598 
1599  return success();
1600 }
1601 
1602 //===----------------------------------------------------------------------===//
1603 // spirv.SpecConstant
1604 //===----------------------------------------------------------------------===//
1605 
1607  OperationState &result) {
1608  StringAttr nameAttr;
1609  Attribute valueAttr;
1610  StringRef defaultValueAttrName =
1611  spirv::SpecConstantOp::getDefaultValueAttrName(result.name);
1612 
1613  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1614  result.attributes))
1615  return failure();
1616 
1617  // Parse optional spec_id.
1619  IntegerAttr specIdAttr;
1620  if (parser.parseLParen() ||
1621  parser.parseAttribute(specIdAttr, kSpecIdAttrName, result.attributes) ||
1622  parser.parseRParen())
1623  return failure();
1624  }
1625 
1626  if (parser.parseEqual() ||
1627  parser.parseAttribute(valueAttr, defaultValueAttrName, result.attributes))
1628  return failure();
1629 
1630  return success();
1631 }
1632 
1634  printer << ' ';
1635  printer.printSymbolName(getSymName());
1636  if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
1637  printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')';
1638  printer << " = " << getDefaultValue();
1639 }
1640 
1642  if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
1643  if (specID.getValue().isNegative())
1644  return emitOpError("SpecId cannot be negative");
1645 
1646  auto value = getDefaultValue();
1647  if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
1648  // Make sure bitwidth is allowed.
1649  if (!llvm::isa<spirv::SPIRVType>(value.getType()))
1650  return emitOpError("default value bitwidth disallowed");
1651  return success();
1652  }
1653  return emitOpError(
1654  "default value can only be a bool, integer, or float scalar");
1655 }
1656 
1657 //===----------------------------------------------------------------------===//
1658 // spirv.VectorShuffle
1659 //===----------------------------------------------------------------------===//
1660 
1662  VectorType resultType = llvm::cast<VectorType>(getType());
1663 
1664  size_t numResultElements = resultType.getNumElements();
1665  if (numResultElements != getComponents().size())
1666  return emitOpError("result type element count (")
1667  << numResultElements
1668  << ") mismatch with the number of component selectors ("
1669  << getComponents().size() << ")";
1670 
1671  size_t totalSrcElements =
1672  llvm::cast<VectorType>(getVector1().getType()).getNumElements() +
1673  llvm::cast<VectorType>(getVector2().getType()).getNumElements();
1674 
1675  for (const auto &selector : getComponents().getAsValueRange<IntegerAttr>()) {
1676  uint32_t index = selector.getZExtValue();
1677  if (index >= totalSrcElements &&
1678  index != std::numeric_limits<uint32_t>().max())
1679  return emitOpError("component selector ")
1680  << index << " out of range: expected to be in [0, "
1681  << totalSrcElements << ") or 0xffffffff";
1682  }
1683  return success();
1684 }
1685 
1686 //===----------------------------------------------------------------------===//
1687 // spirv.MatrixTimesScalar
1688 //===----------------------------------------------------------------------===//
1689 
1691  Type elementType =
1692  llvm::TypeSwitch<Type, Type>(getMatrix().getType())
1694  [](auto matrixType) { return matrixType.getElementType(); })
1695  .Default([](Type) { return nullptr; });
1696 
1697  assert(elementType && "Unhandled type");
1698 
1699  // Check that the scalar type is the same as the matrix element type.
1700  if (getScalar().getType() != elementType)
1701  return emitOpError("input matrix components' type and scaling value must "
1702  "have the same type");
1703 
1704  return success();
1705 }
1706 
1707 //===----------------------------------------------------------------------===//
1708 // spirv.Transpose
1709 //===----------------------------------------------------------------------===//
1710 
1712  auto inputMatrix = llvm::cast<spirv::MatrixType>(getMatrix().getType());
1713  auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().getType());
1714 
1715  // Verify that the input and output matrices have correct shapes.
1716  if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
1717  return emitError("input matrix rows count must be equal to "
1718  "output matrix columns count");
1719 
1720  if (inputMatrix.getNumColumns() != resultMatrix.getNumRows())
1721  return emitError("input matrix columns count must be equal to "
1722  "output matrix rows count");
1723 
1724  // Verify that the input and output matrices have the same component type
1725  if (inputMatrix.getElementType() != resultMatrix.getElementType())
1726  return emitError("input and output matrices must have the same "
1727  "component type");
1728 
1729  return success();
1730 }
1731 
1732 //===----------------------------------------------------------------------===//
1733 // spirv.MatrixTimesMatrix
1734 //===----------------------------------------------------------------------===//
1735 
1737  auto leftMatrix = llvm::cast<spirv::MatrixType>(getLeftmatrix().getType());
1738  auto rightMatrix = llvm::cast<spirv::MatrixType>(getRightmatrix().getType());
1739  auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().getType());
1740 
1741  // left matrix columns' count and right matrix rows' count must be equal
1742  if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
1743  return emitError("left matrix columns' count must be equal to "
1744  "the right matrix rows' count");
1745 
1746  // right and result matrices columns' count must be the same
1747  if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns())
1748  return emitError(
1749  "right and result matrices must have equal columns' count");
1750 
1751  // right and result matrices component type must be the same
1752  if (rightMatrix.getElementType() != resultMatrix.getElementType())
1753  return emitError("right and result matrices' component type must"
1754  " be the same");
1755 
1756  // left and result matrices component type must be the same
1757  if (leftMatrix.getElementType() != resultMatrix.getElementType())
1758  return emitError("left and result matrices' component type"
1759  " must be the same");
1760 
1761  // left and result matrices rows count must be the same
1762  if (leftMatrix.getNumRows() != resultMatrix.getNumRows())
1763  return emitError("left and result matrices must have equal rows' count");
1764 
1765  return success();
1766 }
1767 
1768 //===----------------------------------------------------------------------===//
1769 // spirv.SpecConstantComposite
1770 //===----------------------------------------------------------------------===//
1771 
1773  OperationState &result) {
1774 
1775  StringAttr compositeName;
1776  if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
1777  result.attributes))
1778  return failure();
1779 
1780  if (parser.parseLParen())
1781  return failure();
1782 
1783  SmallVector<Attribute, 4> constituents;
1784 
1785  do {
1786  // The name of the constituent attribute isn't important
1787  const char *attrName = "spec_const";
1788  FlatSymbolRefAttr specConstRef;
1789  NamedAttrList attrs;
1790 
1791  if (parser.parseAttribute(specConstRef, Type(), attrName, attrs))
1792  return failure();
1793 
1794  constituents.push_back(specConstRef);
1795  } while (!parser.parseOptionalComma());
1796 
1797  if (parser.parseRParen())
1798  return failure();
1799 
1800  StringAttr compositeSpecConstituentsName =
1801  spirv::SpecConstantCompositeOp::getConstituentsAttrName(result.name);
1802  result.addAttribute(compositeSpecConstituentsName,
1803  parser.getBuilder().getArrayAttr(constituents));
1804 
1805  Type type;
1806  if (parser.parseColonType(type))
1807  return failure();
1808 
1809  StringAttr typeAttrName =
1810  spirv::SpecConstantCompositeOp::getTypeAttrName(result.name);
1811  result.addAttribute(typeAttrName, TypeAttr::get(type));
1812 
1813  return success();
1814 }
1815 
1817  printer << " ";
1818  printer.printSymbolName(getSymName());
1819  printer << " (";
1820  auto constituents = this->getConstituents().getValue();
1821 
1822  if (!constituents.empty())
1823  llvm::interleaveComma(constituents, printer);
1824 
1825  printer << ") : " << getType();
1826 }
1827 
1829  auto cType = llvm::dyn_cast<spirv::CompositeType>(getType());
1830  auto constituents = this->getConstituents().getValue();
1831 
1832  if (!cType)
1833  return emitError("result type must be a composite type, but provided ")
1834  << getType();
1835 
1836  if (llvm::isa<spirv::CooperativeMatrixType>(cType))
1837  return emitError("unsupported composite type ") << cType;
1838  if (llvm::isa<spirv::JointMatrixINTELType>(cType))
1839  return emitError("unsupported composite type ") << cType;
1840  if (constituents.size() != cType.getNumElements())
1841  return emitError("has incorrect number of operands: expected ")
1842  << cType.getNumElements() << ", but provided "
1843  << constituents.size();
1844 
1845  for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
1846  auto constituent = llvm::cast<FlatSymbolRefAttr>(constituents[index]);
1847 
1848  auto constituentSpecConstOp =
1849  dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom(
1850  (*this)->getParentOp(), constituent.getAttr()));
1851 
1852  if (constituentSpecConstOp.getDefaultValue().getType() !=
1853  cType.getElementType(index))
1854  return emitError("has incorrect types of operands: expected ")
1855  << cType.getElementType(index) << ", but provided "
1856  << constituentSpecConstOp.getDefaultValue().getType();
1857  }
1858 
1859  return success();
1860 }
1861 
1862 //===----------------------------------------------------------------------===//
1863 // spirv.SpecConstantOperation
1864 //===----------------------------------------------------------------------===//
1865 
1867  OperationState &result) {
1868  Region *body = result.addRegion();
1869 
1870  if (parser.parseKeyword("wraps"))
1871  return failure();
1872 
1873  body->push_back(new Block);
1874  Block &block = body->back();
1875  Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
1876 
1877  if (!wrappedOp)
1878  return failure();
1879 
1880  OpBuilder builder(parser.getContext());
1881  builder.setInsertionPointToEnd(&block);
1882  builder.create<spirv::YieldOp>(wrappedOp->getLoc(), wrappedOp->getResult(0));
1883  result.location = wrappedOp->getLoc();
1884 
1885  result.addTypes(wrappedOp->getResult(0).getType());
1886 
1887  if (parser.parseOptionalAttrDict(result.attributes))
1888  return failure();
1889 
1890  return success();
1891 }
1892 
1894  printer << " wraps ";
1895  printer.printGenericOp(&getBody().front().front());
1896 }
1897 
1898 LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
1899  Block &block = getRegion().getBlocks().front();
1900 
1901  if (block.getOperations().size() != 2)
1902  return emitOpError("expected exactly 2 nested ops");
1903 
1904  Operation &enclosedOp = block.getOperations().front();
1905 
1907  return emitOpError("invalid enclosed op");
1908 
1909  for (auto operand : enclosedOp.getOperands())
1910  if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
1911  spirv::SpecConstantOperationOp>(operand.getDefiningOp()))
1912  return emitOpError(
1913  "invalid operand, must be defined by a constant operation");
1914 
1915  return success();
1916 }
1917 
1918 //===----------------------------------------------------------------------===//
1919 // spirv.GL.FrexpStruct
1920 //===----------------------------------------------------------------------===//
1921 
1923  spirv::StructType structTy =
1924  llvm::dyn_cast<spirv::StructType>(getResult().getType());
1925 
1926  if (structTy.getNumElements() != 2)
1927  return emitError("result type must be a struct type with two memebers");
1928 
1929  Type significandTy = structTy.getElementType(0);
1930  Type exponentTy = structTy.getElementType(1);
1931  VectorType exponentVecTy = llvm::dyn_cast<VectorType>(exponentTy);
1932  IntegerType exponentIntTy = llvm::dyn_cast<IntegerType>(exponentTy);
1933 
1934  Type operandTy = getOperand().getType();
1935  VectorType operandVecTy = llvm::dyn_cast<VectorType>(operandTy);
1936  FloatType operandFTy = llvm::dyn_cast<FloatType>(operandTy);
1937 
1938  if (significandTy != operandTy)
1939  return emitError("member zero of the resulting struct type must be the "
1940  "same type as the operand");
1941 
1942  if (exponentVecTy) {
1943  IntegerType componentIntTy =
1944  llvm::dyn_cast<IntegerType>(exponentVecTy.getElementType());
1945  if (!componentIntTy || componentIntTy.getWidth() != 32)
1946  return emitError("member one of the resulting struct type must"
1947  "be a scalar or vector of 32 bit integer type");
1948  } else if (!exponentIntTy || exponentIntTy.getWidth() != 32) {
1949  return emitError("member one of the resulting struct type "
1950  "must be a scalar or vector of 32 bit integer type");
1951  }
1952 
1953  // Check that the two member types have the same number of components
1954  if (operandVecTy && exponentVecTy &&
1955  (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
1956  return success();
1957 
1958  if (operandFTy && exponentIntTy)
1959  return success();
1960 
1961  return emitError("member one of the resulting struct type must have the same "
1962  "number of components as the operand type");
1963 }
1964 
1965 //===----------------------------------------------------------------------===//
1966 // spirv.GL.Ldexp
1967 //===----------------------------------------------------------------------===//
1968 
1970  Type significandType = getX().getType();
1971  Type exponentType = getExp().getType();
1972 
1973  if (llvm::isa<FloatType>(significandType) !=
1974  llvm::isa<IntegerType>(exponentType))
1975  return emitOpError("operands must both be scalars or vectors");
1976 
1977  auto getNumElements = [](Type type) -> unsigned {
1978  if (auto vectorType = llvm::dyn_cast<VectorType>(type))
1979  return vectorType.getNumElements();
1980  return 1;
1981  };
1982 
1983  if (getNumElements(significandType) != getNumElements(exponentType))
1984  return emitOpError("operands must have the same number of elements");
1985 
1986  return success();
1987 }
1988 
1989 //===----------------------------------------------------------------------===//
1990 // spirv.ImageDrefGather
1991 //===----------------------------------------------------------------------===//
1992 
1994  VectorType resultType = llvm::cast<VectorType>(getResult().getType());
1995  auto sampledImageType =
1996  llvm::cast<spirv::SampledImageType>(getSampledimage().getType());
1997  auto imageType =
1998  llvm::cast<spirv::ImageType>(sampledImageType.getImageType());
1999 
2000  if (resultType.getNumElements() != 4)
2001  return emitOpError("result type must be a vector of four components");
2002 
2003  Type elementType = resultType.getElementType();
2004  Type sampledElementType = imageType.getElementType();
2005  if (!llvm::isa<NoneType>(sampledElementType) &&
2006  elementType != sampledElementType)
2007  return emitOpError(
2008  "the component type of result must be the same as sampled type of the "
2009  "underlying image type");
2010 
2011  spirv::Dim imageDim = imageType.getDim();
2012  spirv::ImageSamplingInfo imageMS = imageType.getSamplingInfo();
2013 
2014  if (imageDim != spirv::Dim::Dim2D && imageDim != spirv::Dim::Cube &&
2015  imageDim != spirv::Dim::Rect)
2016  return emitOpError(
2017  "the Dim operand of the underlying image type must be 2D, Cube, or "
2018  "Rect");
2019 
2020  if (imageMS != spirv::ImageSamplingInfo::SingleSampled)
2021  return emitOpError("the MS operand of the underlying image type must be 0");
2022 
2023  spirv::ImageOperandsAttr attr = getImageoperandsAttr();
2024  auto operandArguments = getOperandArguments();
2025 
2026  return verifyImageOperands(*this, attr, operandArguments);
2027 }
2028 
2029 //===----------------------------------------------------------------------===//
2030 // spirv.ShiftLeftLogicalOp
2031 //===----------------------------------------------------------------------===//
2032 
2034  return verifyShiftOp(*this);
2035 }
2036 
2037 //===----------------------------------------------------------------------===//
2038 // spirv.ShiftRightArithmeticOp
2039 //===----------------------------------------------------------------------===//
2040 
2042  return verifyShiftOp(*this);
2043 }
2044 
2045 //===----------------------------------------------------------------------===//
2046 // spirv.ShiftRightLogicalOp
2047 //===----------------------------------------------------------------------===//
2048 
2050  return verifyShiftOp(*this);
2051 }
2052 
2053 //===----------------------------------------------------------------------===//
2054 // spirv.ImageQuerySize
2055 //===----------------------------------------------------------------------===//
2056 
2058  spirv::ImageType imageType =
2059  llvm::cast<spirv::ImageType>(getImage().getType());
2060  Type resultType = getResult().getType();
2061 
2062  spirv::Dim dim = imageType.getDim();
2063  spirv::ImageSamplingInfo samplingInfo = imageType.getSamplingInfo();
2064  spirv::ImageSamplerUseInfo samplerInfo = imageType.getSamplerUseInfo();
2065  switch (dim) {
2066  case spirv::Dim::Dim1D:
2067  case spirv::Dim::Dim2D:
2068  case spirv::Dim::Dim3D:
2069  case spirv::Dim::Cube:
2070  if (samplingInfo != spirv::ImageSamplingInfo::MultiSampled &&
2071  samplerInfo != spirv::ImageSamplerUseInfo::SamplerUnknown &&
2072  samplerInfo != spirv::ImageSamplerUseInfo::NoSampler)
2073  return emitError(
2074  "if Dim is 1D, 2D, 3D, or Cube, "
2075  "it must also have either an MS of 1 or a Sampled of 0 or 2");
2076  break;
2077  case spirv::Dim::Buffer:
2078  case spirv::Dim::Rect:
2079  break;
2080  default:
2081  return emitError("the Dim operand of the image type must "
2082  "be 1D, 2D, 3D, Buffer, Cube, or Rect");
2083  }
2084 
2085  unsigned componentNumber = 0;
2086  switch (dim) {
2087  case spirv::Dim::Dim1D:
2088  case spirv::Dim::Buffer:
2089  componentNumber = 1;
2090  break;
2091  case spirv::Dim::Dim2D:
2092  case spirv::Dim::Cube:
2093  case spirv::Dim::Rect:
2094  componentNumber = 2;
2095  break;
2096  case spirv::Dim::Dim3D:
2097  componentNumber = 3;
2098  break;
2099  default:
2100  break;
2101  }
2102 
2103  if (imageType.getArrayedInfo() == spirv::ImageArrayedInfo::Arrayed)
2104  componentNumber += 1;
2105 
2106  unsigned resultComponentNumber = 1;
2107  if (auto resultVectorType = llvm::dyn_cast<VectorType>(resultType))
2108  resultComponentNumber = resultVectorType.getNumElements();
2109 
2110  if (componentNumber != resultComponentNumber)
2111  return emitError("expected the result to have ")
2112  << componentNumber << " component(s), but found "
2113  << resultComponentNumber << " component(s)";
2114 
2115  return success();
2116 }
2117 
2118 //===----------------------------------------------------------------------===//
2119 // spirv.VectorTimesScalarOp
2120 //===----------------------------------------------------------------------===//
2121 
2123  if (getVector().getType() != getType())
2124  return emitOpError("vector operand and result type mismatch");
2125  auto scalarType = llvm::cast<VectorType>(getType()).getElementType();
2126  if (getScalar().getType() != scalarType)
2127  return emitOpError("scalar operand and result element type match");
2128  return success();
2129 }
static std::string bindingName()
Returns the string name of the Binding decoration.
static std::string descriptorSetName()
Returns the string name of the DescriptorSet decoration.
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser, OperationState &result)
Definition: SPIRVOps.cpp:297
static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, Type opType)
Definition: SPIRVOps.cpp:587
static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, OperationState &result)
Definition: SPIRVOps.cpp:122
static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op)
Definition: SPIRVOps.cpp:283
static LogicalResult verifyShiftOp(Operation *op)
Definition: SPIRVOps.cpp:329
static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op, Value ptr, Value val)
Definition: SPIRVOps.cpp:199
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:216
static void printOneResultOp(Operation *op, OpAsmPrinter &p)
Definition: SPIRVOps.cpp:151
static LogicalResult verifyImageOperands(Op imageOp, spirv::ImageOperandsAttr attr, Operation::operand_range operands)
Definition: SPIRVOps.cpp:171
static void printArithmeticExtendedBinaryOp(Operation *op, OpAsmPrinter &printer)
Definition: SPIRVOps.cpp:321
static int64_t getNumElements(ShapedType type)
Definition: TensorOps.cpp:1541
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseOptionalSymbolName(StringAttr &result)=0
Parse an optional @-identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
ParseResult addTypesToList(ArrayRef< Type > types, SmallVectorImpl< Type > &result)
Add the specified types to the end of the specified type list and return success.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:31
OpListType & getOperations()
Definition: Block.h:135
Operation & front()
Definition: Block.h:151
iterator begin()
Definition: Block.h:141
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:216
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:238
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:283
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:261
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:96
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:87
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:116
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:269
MLIRContext * getContext() const
Definition: Builders.h:55
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:273
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition: Builders.h:100
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:41
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:308
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual Operation * parseGenericOperation(Block *insertBlock, Block::iterator insertPt)=0
Parse an operation in its generic form.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printGenericOp(Operation *op, bool printOpName=true)=0
Print the entire operation with the default generic assembly form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:438
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:437
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:824
A trait to mark ops that can be enclosed/wrapped in a SpecConstantOperation op.
Definition: SPIRVOpTraits.h:33
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
type_range getType() const
Definition: ValueRange.cpp:30
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:745
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:545
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
This class represents success/failure for parsing-like operations that find it important to chain tog...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
void push_back(Block *block)
Definition: Region.h:61
bool empty()
Definition: Region.h:60
Block & back()
Definition: Region.h:64
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:123
Type front()
Return first type in the range.
Definition: TypeRange.h:148
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:34
static WalkResult advance()
Definition: Visitors.h:52
ImageArrayedInfo getArrayedInfo() const
Definition: SPIRVTypes.cpp:426
ImageSamplerUseInfo getSamplerUseInfo() const
Definition: SPIRVTypes.cpp:434
ImageSamplingInfo getSamplingInfo() const
Definition: SPIRVTypes.cpp:430
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:481
SPIR-V struct type.
Definition: SPIRVTypes.h:293
unsigned getNumElements() const
Type getElementType(unsigned) const
An attribute that specifies the SPIR-V (version, capabilities, extensions) triple.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:137
ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the argument at 'index'.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
ParseResult parseFunctionSignature(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
constexpr char kFnNameAttrName[]
constexpr char kSpecIdAttrName[]
LogicalResult verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics)
Definition: SPIRVOps.cpp:71
ParseResult parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next string attribute in parser as an enumerant of the given EnumClass.
void printVariableDecorations(Operation *op, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs)
Definition: SPIRVOps.cpp:95
AddressingModel getAddressingModel(TargetEnvAttr targetAttr, bool use64bitAddress)
Returns addressing model selected based on target environment.
FailureOr< ExecutionModel > getExecutionModel(TargetEnvAttr targetAttr)
Returns execution model selected based on target environment.
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
LogicalResult extractValueFromConstOp(Operation *op, int32_t &value)
Definition: SPIRVOps.cpp:51
ParseResult parseVariableDecorations(OpAsmParser &parser, OperationState &state)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.