MLIR  18.0.0git
SPIRVOps.cpp
Go to the documentation of this file.
1 //===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the operations in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
15 #include "SPIRVOpUtils.h"
16 #include "SPIRVParsingUtils.h"
17 
24 #include "mlir/IR/Builders.h"
25 #include "mlir/IR/BuiltinTypes.h"
26 #include "mlir/IR/Matchers.h"
27 #include "mlir/IR/OpDefinition.h"
29 #include "mlir/IR/Operation.h"
30 #include "mlir/IR/TypeUtilities.h"
33 #include "llvm/ADT/APFloat.h"
34 #include "llvm/ADT/APInt.h"
35 #include "llvm/ADT/ArrayRef.h"
36 #include "llvm/ADT/STLExtras.h"
37 #include "llvm/ADT/StringExtras.h"
38 #include "llvm/ADT/TypeSwitch.h"
39 #include <cassert>
40 #include <numeric>
41 #include <optional>
42 #include <type_traits>
43 
44 using namespace mlir;
45 using namespace mlir::spirv::AttrNames;
46 
47 //===----------------------------------------------------------------------===//
48 // Common utility functions
49 //===----------------------------------------------------------------------===//
50 
52  auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
53  if (!constOp) {
54  return failure();
55  }
56  auto valueAttr = constOp.getValue();
57  auto integerValueAttr = llvm::dyn_cast<IntegerAttr>(valueAttr);
58  if (!integerValueAttr) {
59  return failure();
60  }
61 
62  if (integerValueAttr.getType().isSignlessInteger())
63  value = integerValueAttr.getInt();
64  else
65  value = integerValueAttr.getSInt();
66 
67  return success();
68 }
69 
72  spirv::MemorySemantics memorySemantics) {
73  // According to the SPIR-V specification:
74  // "Despite being a mask and allowing multiple bits to be combined, it is
75  // invalid for more than one of these four bits to be set: Acquire, Release,
76  // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and
77  // Release semantics is done by setting the AcquireRelease bit, not by setting
78  // two bits."
79  auto atMostOneInSet = spirv::MemorySemantics::Acquire |
80  spirv::MemorySemantics::Release |
81  spirv::MemorySemantics::AcquireRelease |
82  spirv::MemorySemantics::SequentiallyConsistent;
83 
84  auto bitCount =
85  llvm::popcount(static_cast<uint32_t>(memorySemantics & atMostOneInSet));
86  if (bitCount > 1) {
87  return op->emitError(
88  "expected at most one of these four memory constraints "
89  "to be set: `Acquire`, `Release`,"
90  "`AcquireRelease` or `SequentiallyConsistent`");
91  }
92  return success();
93 }
94 
96  SmallVectorImpl<StringRef> &elidedAttrs) {
97  // Print optional descriptor binding
98  auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
99  stringifyDecoration(spirv::Decoration::DescriptorSet));
100  auto bindingName = llvm::convertToSnakeFromCamelCase(
101  stringifyDecoration(spirv::Decoration::Binding));
102  auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
103  auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
104  if (descriptorSet && binding) {
105  elidedAttrs.push_back(descriptorSetName);
106  elidedAttrs.push_back(bindingName);
107  printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
108  << ")";
109  }
110 
111  // Print BuiltIn attribute if present
112  auto builtInName = llvm::convertToSnakeFromCamelCase(
113  stringifyDecoration(spirv::Decoration::BuiltIn));
114  if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
115  printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
116  elidedAttrs.push_back(builtInName);
117  }
118 
119  printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
120 }
121 
123  OperationState &result) {
125  Type type;
126  // If the operand list is in-between parentheses, then we have a generic form.
127  // (see the fallback in `printOneResultOp`).
128  SMLoc loc = parser.getCurrentLocation();
129  if (!parser.parseOptionalLParen()) {
130  if (parser.parseOperandList(ops) || parser.parseRParen() ||
131  parser.parseOptionalAttrDict(result.attributes) ||
132  parser.parseColon() || parser.parseType(type))
133  return failure();
134  auto fnType = llvm::dyn_cast<FunctionType>(type);
135  if (!fnType) {
136  parser.emitError(loc, "expected function type");
137  return failure();
138  }
139  if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands))
140  return failure();
141  result.addTypes(fnType.getResults());
142  return success();
143  }
144  return failure(parser.parseOperandList(ops) ||
145  parser.parseOptionalAttrDict(result.attributes) ||
146  parser.parseColonType(type) ||
147  parser.resolveOperands(ops, type, result.operands) ||
148  parser.addTypeToList(type, result.types));
149 }
150 
152  assert(op->getNumResults() == 1 && "op should have one result");
153 
154  // If not all the operand and result types are the same, just use the
155  // generic assembly form to avoid omitting information in printing.
156  auto resultType = op->getResult(0).getType();
157  if (llvm::any_of(op->getOperandTypes(),
158  [&](Type type) { return type != resultType; })) {
159  p.printGenericOp(op, /*printOpName=*/false);
160  return;
161  }
162 
163  p << ' ';
164  p.printOperands(op->getOperands());
166  // Now we can output only one type for all operands and the result.
167  p << " : " << resultType;
168 }
169 
170 template <typename Op>
172  spirv::ImageOperandsAttr attr,
173  Operation::operand_range operands) {
174  if (!attr) {
175  if (operands.empty())
176  return success();
177 
178  return imageOp.emitError("the Image Operands should encode what operands "
179  "follow, as per Image Operands");
180  }
181 
182  // TODO: Add the validation rules for the following Image Operands.
183  spirv::ImageOperands noSupportOperands =
184  spirv::ImageOperands::Bias | spirv::ImageOperands::Lod |
185  spirv::ImageOperands::Grad | spirv::ImageOperands::ConstOffset |
186  spirv::ImageOperands::Offset | spirv::ImageOperands::ConstOffsets |
187  spirv::ImageOperands::Sample | spirv::ImageOperands::MinLod |
188  spirv::ImageOperands::MakeTexelAvailable |
189  spirv::ImageOperands::MakeTexelVisible |
190  spirv::ImageOperands::SignExtend | spirv::ImageOperands::ZeroExtend;
191 
192  if (spirv::bitEnumContainsAll(attr.getValue(), noSupportOperands))
193  llvm_unreachable("unimplemented operands of Image Operands");
194 
195  return success();
196 }
197 
198 template <typename BlockReadWriteOpTy>
200  Value ptr, Value val) {
201  auto valType = val.getType();
202  if (auto valVecTy = llvm::dyn_cast<VectorType>(valType))
203  valType = valVecTy.getElementType();
204 
205  if (valType !=
206  llvm::cast<spirv::PointerType>(ptr.getType()).getPointeeType()) {
207  return op.emitOpError("mismatch in result type and pointer type");
208  }
209  return success();
210 }
211 
212 /// Walks the given type hierarchy with the given indices, potentially down
213 /// to component granularity, to select an element type. Returns null type and
214 /// emits errors with the given loc on failure.
215 static Type
217  function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
218  if (indices.empty()) {
219  emitErrorFn("expected at least one index for spirv.CompositeExtract");
220  return nullptr;
221  }
222 
223  for (auto index : indices) {
224  if (auto cType = llvm::dyn_cast<spirv::CompositeType>(type)) {
225  if (cType.hasCompileTimeKnownNumElements() &&
226  (index < 0 ||
227  static_cast<uint64_t>(index) >= cType.getNumElements())) {
228  emitErrorFn("index ") << index << " out of bounds for " << type;
229  return nullptr;
230  }
231  type = cType.getElementType(index);
232  } else {
233  emitErrorFn("cannot extract from non-composite type ")
234  << type << " with index " << index;
235  return nullptr;
236  }
237  }
238  return type;
239 }
240 
241 static Type
243  function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
244  auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(indices);
245  if (!indicesArrayAttr) {
246  emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
247  return nullptr;
248  }
249  if (indicesArrayAttr.empty()) {
250  emitErrorFn("expected at least one index for spirv.CompositeExtract");
251  return nullptr;
252  }
253 
254  SmallVector<int32_t, 2> indexVals;
255  for (auto indexAttr : indicesArrayAttr) {
256  auto indexIntAttr = llvm::dyn_cast<IntegerAttr>(indexAttr);
257  if (!indexIntAttr) {
258  emitErrorFn("expected an 32-bit integer for index, but found '")
259  << indexAttr << "'";
260  return nullptr;
261  }
262  indexVals.push_back(indexIntAttr.getInt());
263  }
264  return getElementType(type, indexVals, emitErrorFn);
265 }
266 
267 static Type getElementType(Type type, Attribute indices, Location loc) {
268  auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
269  return ::mlir::emitError(loc, err);
270  };
271  return getElementType(type, indices, errorFn);
272 }
273 
274 static Type getElementType(Type type, Attribute indices, OpAsmParser &parser,
275  SMLoc loc) {
276  auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
277  return parser.emitError(loc, err);
278  };
279  return getElementType(type, indices, errorFn);
280 }
281 
282 template <typename ExtendedBinaryOp>
283 static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op) {
284  auto resultType = llvm::cast<spirv::StructType>(op.getType());
285  if (resultType.getNumElements() != 2)
286  return op.emitOpError("expected result struct type containing two members");
287 
288  if (!llvm::all_equal({op.getOperand1().getType(), op.getOperand2().getType(),
289  resultType.getElementType(0),
290  resultType.getElementType(1)}))
291  return op.emitOpError(
292  "expected all operand types and struct member types are the same");
293 
294  return success();
295 }
296 
298  OperationState &result) {
300  if (parser.parseOptionalAttrDict(result.attributes) ||
301  parser.parseOperandList(operands) || parser.parseColon())
302  return failure();
303 
304  Type resultType;
305  SMLoc loc = parser.getCurrentLocation();
306  if (parser.parseType(resultType))
307  return failure();
308 
309  auto structType = llvm::dyn_cast<spirv::StructType>(resultType);
310  if (!structType || structType.getNumElements() != 2)
311  return parser.emitError(loc, "expected spirv.struct type with two members");
312 
313  SmallVector<Type, 2> operandTypes(2, structType.getElementType(0));
314  if (parser.resolveOperands(operands, operandTypes, loc, result.operands))
315  return failure();
316 
317  result.addTypes(resultType);
318  return success();
319 }
320 
322  OpAsmPrinter &printer) {
323  printer << ' ';
324  printer.printOptionalAttrDict(op->getAttrs());
325  printer.printOperands(op->getOperands());
326  printer << " : " << op->getResultTypes().front();
327 }
328 
330  if (op->getOperand(0).getType() != op->getResult(0).getType()) {
331  return op->emitError("expected the same type for the first operand and "
332  "result, but provided ")
333  << op->getOperand(0).getType() << " and "
334  << op->getResult(0).getType();
335  }
336  return success();
337 }
338 
339 //===----------------------------------------------------------------------===//
340 // spirv.mlir.addressof
341 //===----------------------------------------------------------------------===//
342 
343 void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state,
344  spirv::GlobalVariableOp var) {
345  build(builder, state, var.getType(), SymbolRefAttr::get(var));
346 }
347 
349  auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
350  SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(),
351  getVariableAttr()));
352  if (!varOp) {
353  return emitOpError("expected spirv.GlobalVariable symbol");
354  }
355  if (getPointer().getType() != varOp.getType()) {
356  return emitOpError(
357  "result type mismatch with the referenced global variable's type");
358  }
359  return success();
360 }
361 
362 //===----------------------------------------------------------------------===//
363 // spirv.CompositeConstruct
364 //===----------------------------------------------------------------------===//
365 
367  operand_range constituents = this->getConstituents();
368 
369  // There are 4 cases with varying verification rules:
370  // 1. Cooperative Matrices (1 constituent)
371  // 2. Structs (1 constituent for each member)
372  // 3. Arrays (1 constituent for each array element)
373  // 4. Vectors (1 constituent (sub-)element for each vector element)
374 
375  auto coopElementType =
379  [](auto coopType) { return coopType.getElementType(); })
380  .Default([](Type) { return nullptr; });
381 
382  // Case 1. -- matrices.
383  if (coopElementType) {
384  if (constituents.size() != 1)
385  return emitOpError("has incorrect number of operands: expected ")
386  << "1, but provided " << constituents.size();
387  if (coopElementType != constituents.front().getType())
388  return emitOpError("operand type mismatch: expected operand type ")
389  << coopElementType << ", but provided "
390  << constituents.front().getType();
391  return success();
392  }
393 
394  // Case 2./3./4. -- number of constituents matches the number of elements.
395  auto cType = llvm::cast<spirv::CompositeType>(getType());
396  if (constituents.size() == cType.getNumElements()) {
397  for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
398  if (constituents[index].getType() != cType.getElementType(index)) {
399  return emitOpError("operand type mismatch: expected operand type ")
400  << cType.getElementType(index) << ", but provided "
401  << constituents[index].getType();
402  }
403  }
404  return success();
405  }
406 
407  // Case 4. -- check that all constituents add up to the expected vector type.
408  auto resultType = llvm::dyn_cast<VectorType>(cType);
409  if (!resultType)
410  return emitOpError(
411  "expected to return a vector or cooperative matrix when the number of "
412  "constituents is less than what the result needs");
413 
414  SmallVector<unsigned> sizes;
415  for (Value component : constituents) {
416  if (!llvm::isa<VectorType>(component.getType()) &&
417  !component.getType().isIntOrFloat())
418  return emitOpError("operand type mismatch: expected operand to have "
419  "a scalar or vector type, but provided ")
420  << component.getType();
421 
422  Type elementType = component.getType();
423  if (auto vectorType = llvm::dyn_cast<VectorType>(component.getType())) {
424  sizes.push_back(vectorType.getNumElements());
425  elementType = vectorType.getElementType();
426  } else {
427  sizes.push_back(1);
428  }
429 
430  if (elementType != resultType.getElementType())
431  return emitOpError("operand element type mismatch: expected to be ")
432  << resultType.getElementType() << ", but provided " << elementType;
433  }
434  unsigned totalCount = std::accumulate(sizes.begin(), sizes.end(), 0);
435  if (totalCount != cType.getNumElements())
436  return emitOpError("has incorrect number of operands: expected ")
437  << cType.getNumElements() << ", but provided " << totalCount;
438  return success();
439 }
440 
441 //===----------------------------------------------------------------------===//
442 // spirv.CompositeExtractOp
443 //===----------------------------------------------------------------------===//
444 
445 void spirv::CompositeExtractOp::build(OpBuilder &builder, OperationState &state,
446  Value composite,
447  ArrayRef<int32_t> indices) {
448  auto indexAttr = builder.getI32ArrayAttr(indices);
449  auto elementType =
450  getElementType(composite.getType(), indexAttr, state.location);
451  if (!elementType) {
452  return;
453  }
454  build(builder, state, elementType, composite, indexAttr);
455 }
456 
458  OperationState &result) {
459  OpAsmParser::UnresolvedOperand compositeInfo;
460  Attribute indicesAttr;
461  Type compositeType;
462  SMLoc attrLocation;
463 
464  if (parser.parseOperand(compositeInfo) ||
465  parser.getCurrentLocation(&attrLocation) ||
466  parser.parseAttribute(indicesAttr, kIndicesAttrName, result.attributes) ||
467  parser.parseColonType(compositeType) ||
468  parser.resolveOperand(compositeInfo, compositeType, result.operands)) {
469  return failure();
470  }
471 
472  Type resultType =
473  getElementType(compositeType, indicesAttr, parser, attrLocation);
474  if (!resultType) {
475  return failure();
476  }
477  result.addTypes(resultType);
478  return success();
479 }
480 
482  printer << ' ' << getComposite() << getIndices() << " : "
483  << getComposite().getType();
484 }
485 
487  auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(getIndices());
488  auto resultType =
489  getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
490  if (!resultType)
491  return failure();
492 
493  if (resultType != getType()) {
494  return emitOpError("invalid result type: expected ")
495  << resultType << " but provided " << getType();
496  }
497 
498  return success();
499 }
500 
501 //===----------------------------------------------------------------------===//
502 // spirv.CompositeInsert
503 //===----------------------------------------------------------------------===//
504 
505 void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state,
506  Value object, Value composite,
507  ArrayRef<int32_t> indices) {
508  auto indexAttr = builder.getI32ArrayAttr(indices);
509  build(builder, state, composite.getType(), object, composite, indexAttr);
510 }
511 
513  OperationState &result) {
515  Type objectType, compositeType;
516  Attribute indicesAttr;
517  auto loc = parser.getCurrentLocation();
518 
519  return failure(
520  parser.parseOperandList(operands, 2) ||
521  parser.parseAttribute(indicesAttr, kIndicesAttrName, result.attributes) ||
522  parser.parseColonType(objectType) ||
523  parser.parseKeywordType("into", compositeType) ||
524  parser.resolveOperands(operands, {objectType, compositeType}, loc,
525  result.operands) ||
526  parser.addTypesToList(compositeType, result.types));
527 }
528 
530  auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(getIndices());
531  auto objectType =
532  getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
533  if (!objectType)
534  return failure();
535 
536  if (objectType != getObject().getType()) {
537  return emitOpError("object operand type should be ")
538  << objectType << ", but found " << getObject().getType();
539  }
540 
541  if (getComposite().getType() != getType()) {
542  return emitOpError("result type should be the same as "
543  "the composite type, but found ")
544  << getComposite().getType() << " vs " << getType();
545  }
546 
547  return success();
548 }
549 
551  printer << " " << getObject() << ", " << getComposite() << getIndices()
552  << " : " << getObject().getType() << " into "
553  << getComposite().getType();
554 }
555 
556 //===----------------------------------------------------------------------===//
557 // spirv.Constant
558 //===----------------------------------------------------------------------===//
559 
561  OperationState &result) {
562  Attribute value;
563  if (parser.parseAttribute(value, kValueAttrName, result.attributes))
564  return failure();
565 
566  Type type = NoneType::get(parser.getContext());
567  if (auto typedAttr = llvm::dyn_cast<TypedAttr>(value))
568  type = typedAttr.getType();
569  if (llvm::isa<NoneType, TensorType>(type)) {
570  if (parser.parseColonType(type))
571  return failure();
572  }
573 
574  return parser.addTypeToList(type, result.types);
575 }
576 
577 void spirv::ConstantOp::print(OpAsmPrinter &printer) {
578  printer << ' ' << getValue();
579  if (llvm::isa<spirv::ArrayType>(getType()))
580  printer << " : " << getType();
581 }
582 
583 static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
584  Type opType) {
585  if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
586  auto valueType = llvm::cast<TypedAttr>(value).getType();
587  if (valueType != opType)
588  return op.emitOpError("result type (")
589  << opType << ") does not match value type (" << valueType << ")";
590  return success();
591  }
592  if (llvm::isa<DenseIntOrFPElementsAttr, SparseElementsAttr>(value)) {
593  auto valueType = llvm::cast<TypedAttr>(value).getType();
594  if (valueType == opType)
595  return success();
596  auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
597  auto shapedType = llvm::dyn_cast<ShapedType>(valueType);
598  if (!arrayType)
599  return op.emitOpError("result or element type (")
600  << opType << ") does not match value type (" << valueType
601  << "), must be the same or spirv.array";
602 
603  int numElements = arrayType.getNumElements();
604  auto opElemType = arrayType.getElementType();
605  while (auto t = llvm::dyn_cast<spirv::ArrayType>(opElemType)) {
606  numElements *= t.getNumElements();
607  opElemType = t.getElementType();
608  }
609  if (!opElemType.isIntOrFloat())
610  return op.emitOpError("only support nested array result type");
611 
612  auto valueElemType = shapedType.getElementType();
613  if (valueElemType != opElemType) {
614  return op.emitOpError("result element type (")
615  << opElemType << ") does not match value element type ("
616  << valueElemType << ")";
617  }
618 
619  if (numElements != shapedType.getNumElements()) {
620  return op.emitOpError("result number of elements (")
621  << numElements << ") does not match value number of elements ("
622  << shapedType.getNumElements() << ")";
623  }
624  return success();
625  }
626  if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(value)) {
627  auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
628  if (!arrayType)
629  return op.emitOpError(
630  "must have spirv.array result type for array value");
631  Type elemType = arrayType.getElementType();
632  for (Attribute element : arrayAttr.getValue()) {
633  // Verify array elements recursively.
634  if (failed(verifyConstantType(op, element, elemType)))
635  return failure();
636  }
637  return success();
638  }
639  return op.emitOpError("cannot have attribute: ") << value;
640 }
641 
643  // ODS already generates checks to make sure the result type is valid. We just
644  // need to additionally check that the value's attribute type is consistent
645  // with the result type.
646  return verifyConstantType(*this, getValueAttr(), getType());
647 }
648 
649 bool spirv::ConstantOp::isBuildableWith(Type type) {
650  // Must be valid SPIR-V type first.
651  if (!llvm::isa<spirv::SPIRVType>(type))
652  return false;
653 
654  if (isa<SPIRVDialect>(type.getDialect())) {
655  // TODO: support constant struct
656  return llvm::isa<spirv::ArrayType>(type);
657  }
658 
659  return true;
660 }
661 
662 spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
663  OpBuilder &builder) {
664  if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
665  unsigned width = intType.getWidth();
666  if (width == 1)
667  return builder.create<spirv::ConstantOp>(loc, type,
668  builder.getBoolAttr(false));
669  return builder.create<spirv::ConstantOp>(
670  loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
671  }
672  if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
673  return builder.create<spirv::ConstantOp>(
674  loc, type, builder.getFloatAttr(floatType, 0.0));
675  }
676  if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
677  Type elemType = vectorType.getElementType();
678  if (llvm::isa<IntegerType>(elemType)) {
679  return builder.create<spirv::ConstantOp>(
680  loc, type,
681  DenseElementsAttr::get(vectorType,
682  IntegerAttr::get(elemType, 0).getValue()));
683  }
684  if (llvm::isa<FloatType>(elemType)) {
685  return builder.create<spirv::ConstantOp>(
686  loc, type,
687  DenseFPElementsAttr::get(vectorType,
688  FloatAttr::get(elemType, 0.0).getValue()));
689  }
690  }
691 
692  llvm_unreachable("unimplemented types for ConstantOp::getZero()");
693 }
694 
695 spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
696  OpBuilder &builder) {
697  if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
698  unsigned width = intType.getWidth();
699  if (width == 1)
700  return builder.create<spirv::ConstantOp>(loc, type,
701  builder.getBoolAttr(true));
702  return builder.create<spirv::ConstantOp>(
703  loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
704  }
705  if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
706  return builder.create<spirv::ConstantOp>(
707  loc, type, builder.getFloatAttr(floatType, 1.0));
708  }
709  if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
710  Type elemType = vectorType.getElementType();
711  if (llvm::isa<IntegerType>(elemType)) {
712  return builder.create<spirv::ConstantOp>(
713  loc, type,
714  DenseElementsAttr::get(vectorType,
715  IntegerAttr::get(elemType, 1).getValue()));
716  }
717  if (llvm::isa<FloatType>(elemType)) {
718  return builder.create<spirv::ConstantOp>(
719  loc, type,
720  DenseFPElementsAttr::get(vectorType,
721  FloatAttr::get(elemType, 1.0).getValue()));
722  }
723  }
724 
725  llvm_unreachable("unimplemented types for ConstantOp::getOne()");
726 }
727 
728 void mlir::spirv::ConstantOp::getAsmResultNames(
729  llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
730  Type type = getType();
731 
732  SmallString<32> specialNameBuffer;
733  llvm::raw_svector_ostream specialName(specialNameBuffer);
734  specialName << "cst";
735 
736  IntegerType intTy = llvm::dyn_cast<IntegerType>(type);
737 
738  if (IntegerAttr intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
739  if (intTy && intTy.getWidth() == 1) {
740  return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
741  }
742 
743  if (intTy.isSignless()) {
744  specialName << intCst.getInt();
745  } else if (intTy.isUnsigned()) {
746  specialName << intCst.getUInt();
747  } else {
748  specialName << intCst.getSInt();
749  }
750  }
751 
752  if (intTy || llvm::isa<FloatType>(type)) {
753  specialName << '_' << type;
754  }
755 
756  if (auto vecType = llvm::dyn_cast<VectorType>(type)) {
757  specialName << "_vec_";
758  specialName << vecType.getDimSize(0);
759 
760  Type elementType = vecType.getElementType();
761 
762  if (llvm::isa<IntegerType>(elementType) ||
763  llvm::isa<FloatType>(elementType)) {
764  specialName << "x" << elementType;
765  }
766  }
767 
768  setNameFn(getResult(), specialName.str());
769 }
770 
771 void mlir::spirv::AddressOfOp::getAsmResultNames(
772  llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
773  SmallString<32> specialNameBuffer;
774  llvm::raw_svector_ostream specialName(specialNameBuffer);
775  specialName << getVariable() << "_addr";
776  setNameFn(getResult(), specialName.str());
777 }
778 
779 //===----------------------------------------------------------------------===//
780 // spirv.ControlBarrierOp
781 //===----------------------------------------------------------------------===//
782 
784  return verifyMemorySemantics(getOperation(), getMemorySemantics());
785 }
786 
787 //===----------------------------------------------------------------------===//
788 // spirv.EntryPoint
789 //===----------------------------------------------------------------------===//
790 
791 void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state,
792  spirv::ExecutionModel executionModel,
793  spirv::FuncOp function,
794  ArrayRef<Attribute> interfaceVars) {
795  build(builder, state,
796  spirv::ExecutionModelAttr::get(builder.getContext(), executionModel),
797  SymbolRefAttr::get(function), builder.getArrayAttr(interfaceVars));
798 }
799 
801  OperationState &result) {
802  spirv::ExecutionModel execModel;
804  SmallVector<Type, 0> idTypes;
805  SmallVector<Attribute, 4> interfaceVars;
806 
808  if (parseEnumStrAttr<spirv::ExecutionModelAttr>(execModel, parser, result) ||
809  parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) {
810  return failure();
811  }
812 
813  if (!parser.parseOptionalComma()) {
814  // Parse the interface variables
815  if (parser.parseCommaSeparatedList([&]() -> ParseResult {
816  // The name of the interface variable attribute isn't important
817  FlatSymbolRefAttr var;
818  NamedAttrList attrs;
819  if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
820  return failure();
821  interfaceVars.push_back(var);
822  return success();
823  }))
824  return failure();
825  }
827  parser.getBuilder().getArrayAttr(interfaceVars));
828  return success();
829 }
830 
832  printer << " \"" << stringifyExecutionModel(getExecutionModel()) << "\" ";
833  printer.printSymbolName(getFn());
834  auto interfaceVars = getInterface().getValue();
835  if (!interfaceVars.empty()) {
836  printer << ", ";
837  llvm::interleaveComma(interfaceVars, printer);
838  }
839 }
840 
842  // Checks for fn and interface symbol reference are done in spirv::ModuleOp
843  // verification.
844  return success();
845 }
846 
847 //===----------------------------------------------------------------------===//
848 // spirv.ExecutionMode
849 //===----------------------------------------------------------------------===//
850 
851 void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state,
852  spirv::FuncOp function,
853  spirv::ExecutionMode executionMode,
854  ArrayRef<int32_t> params) {
855  build(builder, state, SymbolRefAttr::get(function),
856  spirv::ExecutionModeAttr::get(builder.getContext(), executionMode),
857  builder.getI32ArrayAttr(params));
858 }
859 
861  OperationState &result) {
862  spirv::ExecutionMode execMode;
863  Attribute fn;
864  if (parser.parseAttribute(fn, kFnNameAttrName, result.attributes) ||
865  parseEnumStrAttr<spirv::ExecutionModeAttr>(execMode, parser, result)) {
866  return failure();
867  }
868 
870  Type i32Type = parser.getBuilder().getIntegerType(32);
871  while (!parser.parseOptionalComma()) {
872  NamedAttrList attr;
873  Attribute value;
874  if (parser.parseAttribute(value, i32Type, "value", attr)) {
875  return failure();
876  }
877  values.push_back(llvm::cast<IntegerAttr>(value).getInt());
878  }
880  parser.getBuilder().getI32ArrayAttr(values));
881  return success();
882 }
883 
885  printer << " ";
886  printer.printSymbolName(getFn());
887  printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\"";
888  auto values = this->getValues();
889  if (values.empty())
890  return;
891  printer << ", ";
892  llvm::interleaveComma(values, printer, [&](Attribute a) {
893  printer << llvm::cast<IntegerAttr>(a).getInt();
894  });
895 }
896 
897 //===----------------------------------------------------------------------===//
898 // spirv.func
899 //===----------------------------------------------------------------------===//
900 
903  SmallVector<DictionaryAttr> resultAttrs;
904  SmallVector<Type> resultTypes;
905  auto &builder = parser.getBuilder();
906 
907  // Parse the name as a symbol.
908  StringAttr nameAttr;
909  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
910  result.attributes))
911  return failure();
912 
913  // Parse the function signature.
914  bool isVariadic = false;
916  parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
917  resultAttrs))
918  return failure();
919 
920  SmallVector<Type> argTypes;
921  for (auto &arg : entryArgs)
922  argTypes.push_back(arg.type);
923  auto fnType = builder.getFunctionType(argTypes, resultTypes);
924  result.addAttribute(getFunctionTypeAttrName(result.name),
925  TypeAttr::get(fnType));
926 
927  // Parse the optional function control keyword.
928  spirv::FunctionControl fnControl;
929  if (parseEnumStrAttr<spirv::FunctionControlAttr>(fnControl, parser, result))
930  return failure();
931 
932  // If additional attributes are present, parse them.
933  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
934  return failure();
935 
936  // Add the attributes to the function arguments.
937  assert(resultAttrs.size() == resultTypes.size());
939  builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
940  getResAttrsAttrName(result.name));
941 
942  // Parse the optional function body.
943  auto *body = result.addRegion();
944  OptionalParseResult parseResult =
945  parser.parseOptionalRegion(*body, entryArgs);
946  return failure(parseResult.has_value() && failed(*parseResult));
947 }
948 
949 void spirv::FuncOp::print(OpAsmPrinter &printer) {
950  // Print function name, signature, and control.
951  printer << " ";
952  printer.printSymbolName(getSymName());
953  auto fnType = getFunctionType();
955  printer, *this, fnType.getInputs(),
956  /*isVariadic=*/false, fnType.getResults());
957  printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl())
958  << "\"";
960  printer, *this,
961  {spirv::attributeName<spirv::FunctionControl>(),
962  getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
963  getFunctionControlAttrName()});
964 
965  // Print the body if this is not an external function.
966  Region &body = this->getBody();
967  if (!body.empty()) {
968  printer << ' ';
969  printer.printRegion(body, /*printEntryBlockArgs=*/false,
970  /*printBlockTerminators=*/true);
971  }
972 }
973 
974 LogicalResult spirv::FuncOp::verifyType() {
975  if (getFunctionType().getNumResults() > 1)
976  return emitOpError("cannot have more than one result");
977  return success();
978 }
979 
980 LogicalResult spirv::FuncOp::verifyBody() {
981  FunctionType fnType = getFunctionType();
982 
983  auto walkResult = walk([fnType](Operation *op) -> WalkResult {
984  if (auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
985  if (fnType.getNumResults() != 0)
986  return retOp.emitOpError("cannot be used in functions returning value");
987  } else if (auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
988  if (fnType.getNumResults() != 1)
989  return retOp.emitOpError(
990  "returns 1 value but enclosing function requires ")
991  << fnType.getNumResults() << " results";
992 
993  auto retOperandType = retOp.getValue().getType();
994  auto fnResultType = fnType.getResult(0);
995  if (retOperandType != fnResultType)
996  return retOp.emitOpError(" return value's type (")
997  << retOperandType << ") mismatch with function's result type ("
998  << fnResultType << ")";
999  }
1000  return WalkResult::advance();
1001  });
1002 
1003  // TODO: verify other bits like linkage type.
1004 
1005  return failure(walkResult.wasInterrupted());
1006 }
1007 
1008 void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
1009  StringRef name, FunctionType type,
1010  spirv::FunctionControl control,
1011  ArrayRef<NamedAttribute> attrs) {
1012  state.addAttribute(SymbolTable::getSymbolAttrName(),
1013  builder.getStringAttr(name));
1014  state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
1015  state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
1016  builder.getAttr<spirv::FunctionControlAttr>(control));
1017  state.attributes.append(attrs.begin(), attrs.end());
1018  state.addRegion();
1019 }
1020 
1021 //===----------------------------------------------------------------------===//
1022 // spirv.GLFClampOp
1023 //===----------------------------------------------------------------------===//
1024 
1026  OperationState &result) {
1027  return parseOneResultSameOperandTypeOp(parser, result);
1028 }
1030 
1031 //===----------------------------------------------------------------------===//
1032 // spirv.GLUClampOp
1033 //===----------------------------------------------------------------------===//
1034 
1036  OperationState &result) {
1037  return parseOneResultSameOperandTypeOp(parser, result);
1038 }
1040 
1041 //===----------------------------------------------------------------------===//
1042 // spirv.GLSClampOp
1043 //===----------------------------------------------------------------------===//
1044 
1046  OperationState &result) {
1047  return parseOneResultSameOperandTypeOp(parser, result);
1048 }
1050 
1051 //===----------------------------------------------------------------------===//
1052 // spirv.GLFmaOp
1053 //===----------------------------------------------------------------------===//
1054 
1056  return parseOneResultSameOperandTypeOp(parser, result);
1057 }
1058 void spirv::GLFmaOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1059 
1060 //===----------------------------------------------------------------------===//
1061 // spirv.GlobalVariable
1062 //===----------------------------------------------------------------------===//
1063 
1064 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
1065  Type type, StringRef name,
1066  unsigned descriptorSet, unsigned binding) {
1067  build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
1068  state.addAttribute(
1069  spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
1070  builder.getI32IntegerAttr(descriptorSet));
1071  state.addAttribute(
1072  spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
1073  builder.getI32IntegerAttr(binding));
1074 }
1075 
1076 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
1077  Type type, StringRef name,
1078  spirv::BuiltIn builtin) {
1079  build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
1080  state.addAttribute(
1081  spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
1082  builder.getStringAttr(spirv::stringifyBuiltIn(builtin)));
1083 }
1084 
1086  OperationState &result) {
1087  // Parse variable name.
1088  StringAttr nameAttr;
1089  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1090  result.attributes)) {
1091  return failure();
1092  }
1093 
1094  // Parse optional initializer
1096  FlatSymbolRefAttr initSymbol;
1097  if (parser.parseLParen() ||
1098  parser.parseAttribute(initSymbol, Type(), kInitializerAttrName,
1099  result.attributes) ||
1100  parser.parseRParen())
1101  return failure();
1102  }
1103 
1104  if (parseVariableDecorations(parser, result)) {
1105  return failure();
1106  }
1107 
1108  Type type;
1109  auto loc = parser.getCurrentLocation();
1110  if (parser.parseColonType(type)) {
1111  return failure();
1112  }
1113  if (!llvm::isa<spirv::PointerType>(type)) {
1114  return parser.emitError(loc, "expected spirv.ptr type");
1115  }
1116  result.addAttribute(kTypeAttrName, TypeAttr::get(type));
1117 
1118  return success();
1119 }
1120 
1122  SmallVector<StringRef, 4> elidedAttrs{
1123  spirv::attributeName<spirv::StorageClass>()};
1124 
1125  // Print variable name.
1126  printer << ' ';
1127  printer.printSymbolName(getSymName());
1128  elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
1129 
1130  // Print optional initializer
1131  if (auto initializer = this->getInitializer()) {
1132  printer << " " << kInitializerAttrName << '(';
1133  printer.printSymbolName(*initializer);
1134  printer << ')';
1135  elidedAttrs.push_back(kInitializerAttrName);
1136  }
1137 
1138  elidedAttrs.push_back(kTypeAttrName);
1139  spirv::printVariableDecorations(*this, printer, elidedAttrs);
1140  printer << " : " << getType();
1141 }
1142 
1144  if (!llvm::isa<spirv::PointerType>(getType()))
1145  return emitOpError("result must be of a !spv.ptr type");
1146 
1147  // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
1148  // object. It cannot be Generic. It must be the same as the Storage Class
1149  // operand of the Result Type."
1150  // Also, Function storage class is reserved by spirv.Variable.
1151  auto storageClass = this->storageClass();
1152  if (storageClass == spirv::StorageClass::Generic ||
1153  storageClass == spirv::StorageClass::Function) {
1154  return emitOpError("storage class cannot be '")
1155  << stringifyStorageClass(storageClass) << "'";
1156  }
1157 
1158  if (auto init =
1159  (*this)->getAttrOfType<FlatSymbolRefAttr>(kInitializerAttrName)) {
1161  (*this)->getParentOp(), init.getAttr());
1162  // TODO: Currently only variable initialization with specialization
1163  // constants and other variables is supported. They could be normal
1164  // constants in the module scope as well.
1165  if (!initOp ||
1166  !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp>(initOp)) {
1167  return emitOpError("initializer must be result of a "
1168  "spirv.SpecConstant or spirv.GlobalVariable op");
1169  }
1170  }
1171 
1172  return success();
1173 }
1174 
1175 //===----------------------------------------------------------------------===//
1176 // spirv.INTEL.SubgroupBlockRead
1177 //===----------------------------------------------------------------------===//
1178 
1180  OperationState &result) {
1181  // Parse the storage class specification
1182  spirv::StorageClass storageClass;
1184  Type elementType;
1185  if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
1186  parser.parseColon() || parser.parseType(elementType)) {
1187  return failure();
1188  }
1189 
1190  auto ptrType = spirv::PointerType::get(elementType, storageClass);
1191  if (auto valVecTy = llvm::dyn_cast<VectorType>(elementType))
1192  ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
1193 
1194  if (parser.resolveOperand(ptrInfo, ptrType, result.operands)) {
1195  return failure();
1196  }
1197 
1198  result.addTypes(elementType);
1199  return success();
1200 }
1201 
1203  printer << " " << getPtr() << " : " << getType();
1204 }
1205 
1207  if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
1208  return failure();
1209 
1210  return success();
1211 }
1212 
1213 //===----------------------------------------------------------------------===//
1214 // spirv.INTEL.SubgroupBlockWrite
1215 //===----------------------------------------------------------------------===//
1216 
1218  OperationState &result) {
1219  // Parse the storage class specification
1220  spirv::StorageClass storageClass;
1222  auto loc = parser.getCurrentLocation();
1223  Type elementType;
1224  if (parseEnumStrAttr(storageClass, parser) ||
1225  parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
1226  parser.parseType(elementType)) {
1227  return failure();
1228  }
1229 
1230  auto ptrType = spirv::PointerType::get(elementType, storageClass);
1231  if (auto valVecTy = llvm::dyn_cast<VectorType>(elementType))
1232  ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
1233 
1234  if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
1235  result.operands)) {
1236  return failure();
1237  }
1238  return success();
1239 }
1240 
1242  printer << " " << getPtr() << ", " << getValue() << " : "
1243  << getValue().getType();
1244 }
1245 
1247  if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
1248  return failure();
1249 
1250  return success();
1251 }
1252 
1253 //===----------------------------------------------------------------------===//
1254 // spirv.IAddCarryOp
1255 //===----------------------------------------------------------------------===//
1256 
1259 }
1260 
1262  OperationState &result) {
1264 }
1265 
1266 void spirv::IAddCarryOp::print(OpAsmPrinter &printer) {
1267  ::printArithmeticExtendedBinaryOp(*this, printer);
1268 }
1269 
1270 //===----------------------------------------------------------------------===//
1271 // spirv.ISubBorrowOp
1272 //===----------------------------------------------------------------------===//
1273 
1276 }
1277 
1279  OperationState &result) {
1281 }
1282 
1284  ::printArithmeticExtendedBinaryOp(*this, printer);
1285 }
1286 
1287 //===----------------------------------------------------------------------===//
1288 // spirv.SMulExtended
1289 //===----------------------------------------------------------------------===//
1290 
1293 }
1294 
1296  OperationState &result) {
1298 }
1299 
1301  ::printArithmeticExtendedBinaryOp(*this, printer);
1302 }
1303 
1304 //===----------------------------------------------------------------------===//
1305 // spirv.UMulExtended
1306 //===----------------------------------------------------------------------===//
1307 
1310 }
1311 
1313  OperationState &result) {
1315 }
1316 
1318  ::printArithmeticExtendedBinaryOp(*this, printer);
1319 }
1320 
1321 //===----------------------------------------------------------------------===//
1322 // spirv.MemoryBarrierOp
1323 //===----------------------------------------------------------------------===//
1324 
1326  return verifyMemorySemantics(getOperation(), getMemorySemantics());
1327 }
1328 
1329 //===----------------------------------------------------------------------===//
1330 // spirv.module
1331 //===----------------------------------------------------------------------===//
1332 
1333 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
1334  std::optional<StringRef> name) {
1335  OpBuilder::InsertionGuard guard(builder);
1336  builder.createBlock(state.addRegion());
1337  if (name) {
1338  state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
1339  builder.getStringAttr(*name));
1340  }
1341 }
1342 
1343 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
1344  spirv::AddressingModel addressingModel,
1345  spirv::MemoryModel memoryModel,
1346  std::optional<VerCapExtAttr> vceTriple,
1347  std::optional<StringRef> name) {
1348  state.addAttribute(
1349  "addressing_model",
1350  builder.getAttr<spirv::AddressingModelAttr>(addressingModel));
1351  state.addAttribute("memory_model",
1352  builder.getAttr<spirv::MemoryModelAttr>(memoryModel));
1353  OpBuilder::InsertionGuard guard(builder);
1354  builder.createBlock(state.addRegion());
1355  if (vceTriple)
1356  state.addAttribute(getVCETripleAttrName(), *vceTriple);
1357  if (name)
1358  state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
1359  builder.getStringAttr(*name));
1360 }
1361 
1363  OperationState &result) {
1364  Region *body = result.addRegion();
1365 
1366  // If the name is present, parse it.
1367  StringAttr nameAttr;
1368  (void)parser.parseOptionalSymbolName(
1369  nameAttr, mlir::SymbolTable::getSymbolAttrName(), result.attributes);
1370 
1371  // Parse attributes
1372  spirv::AddressingModel addrModel;
1373  spirv::MemoryModel memoryModel;
1374  if (spirv::parseEnumKeywordAttr<spirv::AddressingModelAttr>(addrModel, parser,
1375  result) ||
1376  spirv::parseEnumKeywordAttr<spirv::MemoryModelAttr>(memoryModel, parser,
1377  result))
1378  return failure();
1379 
1380  if (succeeded(parser.parseOptionalKeyword("requires"))) {
1381  spirv::VerCapExtAttr vceTriple;
1382  if (parser.parseAttribute(vceTriple,
1383  spirv::ModuleOp::getVCETripleAttrName(),
1384  result.attributes))
1385  return failure();
1386  }
1387 
1388  if (parser.parseOptionalAttrDictWithKeyword(result.attributes) ||
1389  parser.parseRegion(*body, /*arguments=*/{}))
1390  return failure();
1391 
1392  // Make sure we have at least one block.
1393  if (body->empty())
1394  body->push_back(new Block());
1395 
1396  return success();
1397 }
1398 
1399 void spirv::ModuleOp::print(OpAsmPrinter &printer) {
1400  if (std::optional<StringRef> name = getName()) {
1401  printer << ' ';
1402  printer.printSymbolName(*name);
1403  }
1404 
1405  SmallVector<StringRef, 2> elidedAttrs;
1406 
1407  printer << " " << spirv::stringifyAddressingModel(getAddressingModel()) << " "
1408  << spirv::stringifyMemoryModel(getMemoryModel());
1409  auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
1410  auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
1411  elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
1413 
1414  if (std::optional<spirv::VerCapExtAttr> triple = getVceTriple()) {
1415  printer << " requires " << *triple;
1416  elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
1417  }
1418 
1419  printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs);
1420  printer << ' ';
1421  printer.printRegion(getRegion());
1422 }
1423 
1424 LogicalResult spirv::ModuleOp::verifyRegions() {
1425  Dialect *dialect = (*this)->getDialect();
1427  entryPoints;
1428  mlir::SymbolTable table(*this);
1429 
1430  for (auto &op : *getBody()) {
1431  if (op.getDialect() != dialect)
1432  return op.emitError("'spirv.module' can only contain spirv.* ops");
1433 
1434  // For EntryPoint op, check that the function and execution model is not
1435  // duplicated in EntryPointOps. Also verify that the interface specified
1436  // comes from globalVariables here to make this check cheaper.
1437  if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
1438  auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.getFn());
1439  if (!funcOp) {
1440  return entryPointOp.emitError("function '")
1441  << entryPointOp.getFn() << "' not found in 'spirv.module'";
1442  }
1443  if (auto interface = entryPointOp.getInterface()) {
1444  for (Attribute varRef : interface) {
1445  auto varSymRef = llvm::dyn_cast<FlatSymbolRefAttr>(varRef);
1446  if (!varSymRef) {
1447  return entryPointOp.emitError(
1448  "expected symbol reference for interface "
1449  "specification instead of '")
1450  << varRef;
1451  }
1452  auto variableOp =
1453  table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
1454  if (!variableOp) {
1455  return entryPointOp.emitError("expected spirv.GlobalVariable "
1456  "symbol reference instead of'")
1457  << varSymRef << "'";
1458  }
1459  }
1460  }
1461 
1462  auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
1463  funcOp, entryPointOp.getExecutionModel());
1464  auto entryPtIt = entryPoints.find(key);
1465  if (entryPtIt != entryPoints.end()) {
1466  return entryPointOp.emitError("duplicate of a previous EntryPointOp");
1467  }
1468  entryPoints[key] = entryPointOp;
1469  } else if (auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
1470  // If the function is external and does not have 'Import'
1471  // linkage_attributes(LinkageAttributes), throw an error. 'Import'
1472  // LinkageAttributes is used to import external functions.
1473  auto linkageAttr = funcOp.getLinkageAttributes();
1474  auto hasImportLinkage =
1475  linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
1476  spirv::LinkageType::Import);
1477  if (funcOp.isExternal() && !hasImportLinkage)
1478  return op.emitError(
1479  "'spirv.module' cannot contain external functions "
1480  "without 'Import' linkage_attributes (LinkageAttributes)");
1481 
1482  // TODO: move this check to spirv.func.
1483  for (auto &block : funcOp)
1484  for (auto &op : block) {
1485  if (op.getDialect() != dialect)
1486  return op.emitError(
1487  "functions in 'spirv.module' can only contain spirv.* ops");
1488  }
1489  }
1490  }
1491 
1492  return success();
1493 }
1494 
1495 //===----------------------------------------------------------------------===//
1496 // spirv.mlir.referenceof
1497 //===----------------------------------------------------------------------===//
1498 
1500  auto *specConstSym = SymbolTable::lookupNearestSymbolFrom(
1501  (*this)->getParentOp(), getSpecConstAttr());
1502  Type constType;
1503 
1504  auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
1505  if (specConstOp)
1506  constType = specConstOp.getDefaultValue().getType();
1507 
1508  auto specConstCompositeOp =
1509  dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
1510  if (specConstCompositeOp)
1511  constType = specConstCompositeOp.getType();
1512 
1513  if (!specConstOp && !specConstCompositeOp)
1514  return emitOpError(
1515  "expected spirv.SpecConstant or spirv.SpecConstantComposite symbol");
1516 
1517  if (getReference().getType() != constType)
1518  return emitOpError("result type mismatch with the referenced "
1519  "specialization constant's type");
1520 
1521  return success();
1522 }
1523 
1524 //===----------------------------------------------------------------------===//
1525 // spirv.SpecConstant
1526 //===----------------------------------------------------------------------===//
1527 
1529  OperationState &result) {
1530  StringAttr nameAttr;
1531  Attribute valueAttr;
1532 
1533  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1534  result.attributes))
1535  return failure();
1536 
1537  // Parse optional spec_id.
1539  IntegerAttr specIdAttr;
1540  if (parser.parseLParen() ||
1541  parser.parseAttribute(specIdAttr, kSpecIdAttrName, result.attributes) ||
1542  parser.parseRParen())
1543  return failure();
1544  }
1545 
1546  if (parser.parseEqual() ||
1547  parser.parseAttribute(valueAttr, kDefaultValueAttrName,
1548  result.attributes))
1549  return failure();
1550 
1551  return success();
1552 }
1553 
1555  printer << ' ';
1556  printer.printSymbolName(getSymName());
1557  if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
1558  printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')';
1559  printer << " = " << getDefaultValue();
1560 }
1561 
1563  if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
1564  if (specID.getValue().isNegative())
1565  return emitOpError("SpecId cannot be negative");
1566 
1567  auto value = getDefaultValue();
1568  if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
1569  // Make sure bitwidth is allowed.
1570  if (!llvm::isa<spirv::SPIRVType>(value.getType()))
1571  return emitOpError("default value bitwidth disallowed");
1572  return success();
1573  }
1574  return emitOpError(
1575  "default value can only be a bool, integer, or float scalar");
1576 }
1577 
1578 //===----------------------------------------------------------------------===//
1579 // spirv.VectorShuffle
1580 //===----------------------------------------------------------------------===//
1581 
1583  VectorType resultType = llvm::cast<VectorType>(getType());
1584 
1585  size_t numResultElements = resultType.getNumElements();
1586  if (numResultElements != getComponents().size())
1587  return emitOpError("result type element count (")
1588  << numResultElements
1589  << ") mismatch with the number of component selectors ("
1590  << getComponents().size() << ")";
1591 
1592  size_t totalSrcElements =
1593  llvm::cast<VectorType>(getVector1().getType()).getNumElements() +
1594  llvm::cast<VectorType>(getVector2().getType()).getNumElements();
1595 
1596  for (const auto &selector : getComponents().getAsValueRange<IntegerAttr>()) {
1597  uint32_t index = selector.getZExtValue();
1598  if (index >= totalSrcElements &&
1599  index != std::numeric_limits<uint32_t>().max())
1600  return emitOpError("component selector ")
1601  << index << " out of range: expected to be in [0, "
1602  << totalSrcElements << ") or 0xffffffff";
1603  }
1604  return success();
1605 }
1606 
1607 //===----------------------------------------------------------------------===//
1608 // spirv.MatrixTimesScalar
1609 //===----------------------------------------------------------------------===//
1610 
1612  Type elementType =
1613  llvm::TypeSwitch<Type, Type>(getMatrix().getType())
1616  [](auto matrixType) { return matrixType.getElementType(); })
1617  .Default([](Type) { return nullptr; });
1618 
1619  assert(elementType && "Unhandled type");
1620 
1621  // Check that the scalar type is the same as the matrix element type.
1622  if (getScalar().getType() != elementType)
1623  return emitOpError("input matrix components' type and scaling value must "
1624  "have the same type");
1625 
1626  return success();
1627 }
1628 
1629 //===----------------------------------------------------------------------===//
1630 // spirv.Transpose
1631 //===----------------------------------------------------------------------===//
1632 
1634  auto inputMatrix = llvm::cast<spirv::MatrixType>(getMatrix().getType());
1635  auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().getType());
1636 
1637  // Verify that the input and output matrices have correct shapes.
1638  if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
1639  return emitError("input matrix rows count must be equal to "
1640  "output matrix columns count");
1641 
1642  if (inputMatrix.getNumColumns() != resultMatrix.getNumRows())
1643  return emitError("input matrix columns count must be equal to "
1644  "output matrix rows count");
1645 
1646  // Verify that the input and output matrices have the same component type
1647  if (inputMatrix.getElementType() != resultMatrix.getElementType())
1648  return emitError("input and output matrices must have the same "
1649  "component type");
1650 
1651  return success();
1652 }
1653 
1654 //===----------------------------------------------------------------------===//
1655 // spirv.MatrixTimesMatrix
1656 //===----------------------------------------------------------------------===//
1657 
1659  auto leftMatrix = llvm::cast<spirv::MatrixType>(getLeftmatrix().getType());
1660  auto rightMatrix = llvm::cast<spirv::MatrixType>(getRightmatrix().getType());
1661  auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().getType());
1662 
1663  // left matrix columns' count and right matrix rows' count must be equal
1664  if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
1665  return emitError("left matrix columns' count must be equal to "
1666  "the right matrix rows' count");
1667 
1668  // right and result matrices columns' count must be the same
1669  if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns())
1670  return emitError(
1671  "right and result matrices must have equal columns' count");
1672 
1673  // right and result matrices component type must be the same
1674  if (rightMatrix.getElementType() != resultMatrix.getElementType())
1675  return emitError("right and result matrices' component type must"
1676  " be the same");
1677 
1678  // left and result matrices component type must be the same
1679  if (leftMatrix.getElementType() != resultMatrix.getElementType())
1680  return emitError("left and result matrices' component type"
1681  " must be the same");
1682 
1683  // left and result matrices rows count must be the same
1684  if (leftMatrix.getNumRows() != resultMatrix.getNumRows())
1685  return emitError("left and result matrices must have equal rows' count");
1686 
1687  return success();
1688 }
1689 
1690 //===----------------------------------------------------------------------===//
1691 // spirv.SpecConstantComposite
1692 //===----------------------------------------------------------------------===//
1693 
1695  OperationState &result) {
1696 
1697  StringAttr compositeName;
1698  if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
1699  result.attributes))
1700  return failure();
1701 
1702  if (parser.parseLParen())
1703  return failure();
1704 
1705  SmallVector<Attribute, 4> constituents;
1706 
1707  do {
1708  // The name of the constituent attribute isn't important
1709  const char *attrName = "spec_const";
1710  FlatSymbolRefAttr specConstRef;
1711  NamedAttrList attrs;
1712 
1713  if (parser.parseAttribute(specConstRef, Type(), attrName, attrs))
1714  return failure();
1715 
1716  constituents.push_back(specConstRef);
1717  } while (!parser.parseOptionalComma());
1718 
1719  if (parser.parseRParen())
1720  return failure();
1721 
1723  parser.getBuilder().getArrayAttr(constituents));
1724 
1725  Type type;
1726  if (parser.parseColonType(type))
1727  return failure();
1728 
1729  result.addAttribute(kTypeAttrName, TypeAttr::get(type));
1730 
1731  return success();
1732 }
1733 
1735  printer << " ";
1736  printer.printSymbolName(getSymName());
1737  printer << " (";
1738  auto constituents = this->getConstituents().getValue();
1739 
1740  if (!constituents.empty())
1741  llvm::interleaveComma(constituents, printer);
1742 
1743  printer << ") : " << getType();
1744 }
1745 
1747  auto cType = llvm::dyn_cast<spirv::CompositeType>(getType());
1748  auto constituents = this->getConstituents().getValue();
1749 
1750  if (!cType)
1751  return emitError("result type must be a composite type, but provided ")
1752  << getType();
1753 
1754  if (llvm::isa<spirv::CooperativeMatrixNVType>(cType))
1755  return emitError("unsupported composite type ") << cType;
1756  if (llvm::isa<spirv::JointMatrixINTELType>(cType))
1757  return emitError("unsupported composite type ") << cType;
1758  if (constituents.size() != cType.getNumElements())
1759  return emitError("has incorrect number of operands: expected ")
1760  << cType.getNumElements() << ", but provided "
1761  << constituents.size();
1762 
1763  for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
1764  auto constituent = llvm::cast<FlatSymbolRefAttr>(constituents[index]);
1765 
1766  auto constituentSpecConstOp =
1767  dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom(
1768  (*this)->getParentOp(), constituent.getAttr()));
1769 
1770  if (constituentSpecConstOp.getDefaultValue().getType() !=
1771  cType.getElementType(index))
1772  return emitError("has incorrect types of operands: expected ")
1773  << cType.getElementType(index) << ", but provided "
1774  << constituentSpecConstOp.getDefaultValue().getType();
1775  }
1776 
1777  return success();
1778 }
1779 
1780 //===----------------------------------------------------------------------===//
1781 // spirv.SpecConstantOperation
1782 //===----------------------------------------------------------------------===//
1783 
1785  OperationState &result) {
1786  Region *body = result.addRegion();
1787 
1788  if (parser.parseKeyword("wraps"))
1789  return failure();
1790 
1791  body->push_back(new Block);
1792  Block &block = body->back();
1793  Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
1794 
1795  if (!wrappedOp)
1796  return failure();
1797 
1798  OpBuilder builder(parser.getContext());
1799  builder.setInsertionPointToEnd(&block);
1800  builder.create<spirv::YieldOp>(wrappedOp->getLoc(), wrappedOp->getResult(0));
1801  result.location = wrappedOp->getLoc();
1802 
1803  result.addTypes(wrappedOp->getResult(0).getType());
1804 
1805  if (parser.parseOptionalAttrDict(result.attributes))
1806  return failure();
1807 
1808  return success();
1809 }
1810 
1812  printer << " wraps ";
1813  printer.printGenericOp(&getBody().front().front());
1814 }
1815 
1816 LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
1817  Block &block = getRegion().getBlocks().front();
1818 
1819  if (block.getOperations().size() != 2)
1820  return emitOpError("expected exactly 2 nested ops");
1821 
1822  Operation &enclosedOp = block.getOperations().front();
1823 
1825  return emitOpError("invalid enclosed op");
1826 
1827  for (auto operand : enclosedOp.getOperands())
1828  if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
1829  spirv::SpecConstantOperationOp>(operand.getDefiningOp()))
1830  return emitOpError(
1831  "invalid operand, must be defined by a constant operation");
1832 
1833  return success();
1834 }
1835 
1836 //===----------------------------------------------------------------------===//
1837 // spirv.GL.FrexpStruct
1838 //===----------------------------------------------------------------------===//
1839 
1841  spirv::StructType structTy =
1842  llvm::dyn_cast<spirv::StructType>(getResult().getType());
1843 
1844  if (structTy.getNumElements() != 2)
1845  return emitError("result type must be a struct type with two memebers");
1846 
1847  Type significandTy = structTy.getElementType(0);
1848  Type exponentTy = structTy.getElementType(1);
1849  VectorType exponentVecTy = llvm::dyn_cast<VectorType>(exponentTy);
1850  IntegerType exponentIntTy = llvm::dyn_cast<IntegerType>(exponentTy);
1851 
1852  Type operandTy = getOperand().getType();
1853  VectorType operandVecTy = llvm::dyn_cast<VectorType>(operandTy);
1854  FloatType operandFTy = llvm::dyn_cast<FloatType>(operandTy);
1855 
1856  if (significandTy != operandTy)
1857  return emitError("member zero of the resulting struct type must be the "
1858  "same type as the operand");
1859 
1860  if (exponentVecTy) {
1861  IntegerType componentIntTy =
1862  llvm::dyn_cast<IntegerType>(exponentVecTy.getElementType());
1863  if (!componentIntTy || componentIntTy.getWidth() != 32)
1864  return emitError("member one of the resulting struct type must"
1865  "be a scalar or vector of 32 bit integer type");
1866  } else if (!exponentIntTy || exponentIntTy.getWidth() != 32) {
1867  return emitError("member one of the resulting struct type "
1868  "must be a scalar or vector of 32 bit integer type");
1869  }
1870 
1871  // Check that the two member types have the same number of components
1872  if (operandVecTy && exponentVecTy &&
1873  (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
1874  return success();
1875 
1876  if (operandFTy && exponentIntTy)
1877  return success();
1878 
1879  return emitError("member one of the resulting struct type must have the same "
1880  "number of components as the operand type");
1881 }
1882 
1883 //===----------------------------------------------------------------------===//
1884 // spirv.GL.Ldexp
1885 //===----------------------------------------------------------------------===//
1886 
1888  Type significandType = getX().getType();
1889  Type exponentType = getExp().getType();
1890 
1891  if (llvm::isa<FloatType>(significandType) !=
1892  llvm::isa<IntegerType>(exponentType))
1893  return emitOpError("operands must both be scalars or vectors");
1894 
1895  auto getNumElements = [](Type type) -> unsigned {
1896  if (auto vectorType = llvm::dyn_cast<VectorType>(type))
1897  return vectorType.getNumElements();
1898  return 1;
1899  };
1900 
1901  if (getNumElements(significandType) != getNumElements(exponentType))
1902  return emitOpError("operands must have the same number of elements");
1903 
1904  return success();
1905 }
1906 
1907 //===----------------------------------------------------------------------===//
1908 // spirv.ImageDrefGather
1909 //===----------------------------------------------------------------------===//
1910 
1912  VectorType resultType = llvm::cast<VectorType>(getResult().getType());
1913  auto sampledImageType =
1914  llvm::cast<spirv::SampledImageType>(getSampledimage().getType());
1915  auto imageType =
1916  llvm::cast<spirv::ImageType>(sampledImageType.getImageType());
1917 
1918  if (resultType.getNumElements() != 4)
1919  return emitOpError("result type must be a vector of four components");
1920 
1921  Type elementType = resultType.getElementType();
1922  Type sampledElementType = imageType.getElementType();
1923  if (!llvm::isa<NoneType>(sampledElementType) &&
1924  elementType != sampledElementType)
1925  return emitOpError(
1926  "the component type of result must be the same as sampled type of the "
1927  "underlying image type");
1928 
1929  spirv::Dim imageDim = imageType.getDim();
1930  spirv::ImageSamplingInfo imageMS = imageType.getSamplingInfo();
1931 
1932  if (imageDim != spirv::Dim::Dim2D && imageDim != spirv::Dim::Cube &&
1933  imageDim != spirv::Dim::Rect)
1934  return emitOpError(
1935  "the Dim operand of the underlying image type must be 2D, Cube, or "
1936  "Rect");
1937 
1938  if (imageMS != spirv::ImageSamplingInfo::SingleSampled)
1939  return emitOpError("the MS operand of the underlying image type must be 0");
1940 
1941  spirv::ImageOperandsAttr attr = getImageoperandsAttr();
1942  auto operandArguments = getOperandArguments();
1943 
1944  return verifyImageOperands(*this, attr, operandArguments);
1945 }
1946 
1947 //===----------------------------------------------------------------------===//
1948 // spirv.ShiftLeftLogicalOp
1949 //===----------------------------------------------------------------------===//
1950 
1952  return verifyShiftOp(*this);
1953 }
1954 
1955 //===----------------------------------------------------------------------===//
1956 // spirv.ShiftRightArithmeticOp
1957 //===----------------------------------------------------------------------===//
1958 
1960  return verifyShiftOp(*this);
1961 }
1962 
1963 //===----------------------------------------------------------------------===//
1964 // spirv.ShiftRightLogicalOp
1965 //===----------------------------------------------------------------------===//
1966 
1968  return verifyShiftOp(*this);
1969 }
1970 
1971 //===----------------------------------------------------------------------===//
1972 // spirv.BitwiseAndOp
1973 //===----------------------------------------------------------------------===//
1974 
1976 spirv::BitwiseAndOp::fold(spirv::BitwiseAndOp::FoldAdaptor adaptor) {
1977  APInt rhsMask;
1978  if (!matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask)))
1979  return {};
1980 
1981  // x & 0 -> 0
1982  if (rhsMask.isZero())
1983  return getOperand2();
1984 
1985  // x & <all ones> -> x
1986  if (rhsMask.isAllOnes())
1987  return getOperand1();
1988 
1989  // (UConvert x : iN to iK) & <mask with N low bits set> -> UConvert x
1990  if (auto zext = getOperand1().getDefiningOp<spirv::UConvertOp>()) {
1991  int valueBits =
1993  if (rhsMask.zextOrTrunc(valueBits).isAllOnes())
1994  return getOperand1();
1995  }
1996 
1997  return {};
1998 }
1999 
2000 //===----------------------------------------------------------------------===//
2001 // spirv.BitwiseOrOp
2002 //===----------------------------------------------------------------------===//
2003 
2004 OpFoldResult spirv::BitwiseOrOp::fold(spirv::BitwiseOrOp::FoldAdaptor adaptor) {
2005  APInt rhsMask;
2006  if (!matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask)))
2007  return {};
2008 
2009  // x | 0 -> x
2010  if (rhsMask.isZero())
2011  return getOperand1();
2012 
2013  // x | <all ones> -> <all ones>
2014  if (rhsMask.isAllOnes())
2015  return getOperand2();
2016 
2017  return {};
2018 }
2019 
2020 //===----------------------------------------------------------------------===//
2021 // spirv.ImageQuerySize
2022 //===----------------------------------------------------------------------===//
2023 
2025  spirv::ImageType imageType =
2026  llvm::cast<spirv::ImageType>(getImage().getType());
2027  Type resultType = getResult().getType();
2028 
2029  spirv::Dim dim = imageType.getDim();
2030  spirv::ImageSamplingInfo samplingInfo = imageType.getSamplingInfo();
2031  spirv::ImageSamplerUseInfo samplerInfo = imageType.getSamplerUseInfo();
2032  switch (dim) {
2033  case spirv::Dim::Dim1D:
2034  case spirv::Dim::Dim2D:
2035  case spirv::Dim::Dim3D:
2036  case spirv::Dim::Cube:
2037  if (samplingInfo != spirv::ImageSamplingInfo::MultiSampled &&
2038  samplerInfo != spirv::ImageSamplerUseInfo::SamplerUnknown &&
2039  samplerInfo != spirv::ImageSamplerUseInfo::NoSampler)
2040  return emitError(
2041  "if Dim is 1D, 2D, 3D, or Cube, "
2042  "it must also have either an MS of 1 or a Sampled of 0 or 2");
2043  break;
2044  case spirv::Dim::Buffer:
2045  case spirv::Dim::Rect:
2046  break;
2047  default:
2048  return emitError("the Dim operand of the image type must "
2049  "be 1D, 2D, 3D, Buffer, Cube, or Rect");
2050  }
2051 
2052  unsigned componentNumber = 0;
2053  switch (dim) {
2054  case spirv::Dim::Dim1D:
2055  case spirv::Dim::Buffer:
2056  componentNumber = 1;
2057  break;
2058  case spirv::Dim::Dim2D:
2059  case spirv::Dim::Cube:
2060  case spirv::Dim::Rect:
2061  componentNumber = 2;
2062  break;
2063  case spirv::Dim::Dim3D:
2064  componentNumber = 3;
2065  break;
2066  default:
2067  break;
2068  }
2069 
2070  if (imageType.getArrayedInfo() == spirv::ImageArrayedInfo::Arrayed)
2071  componentNumber += 1;
2072 
2073  unsigned resultComponentNumber = 1;
2074  if (auto resultVectorType = llvm::dyn_cast<VectorType>(resultType))
2075  resultComponentNumber = resultVectorType.getNumElements();
2076 
2077  if (componentNumber != resultComponentNumber)
2078  return emitError("expected the result to have ")
2079  << componentNumber << " component(s), but found "
2080  << resultComponentNumber << " component(s)";
2081 
2082  return success();
2083 }
2084 
2085 //===----------------------------------------------------------------------===//
2086 // spirv.VectorTimesScalarOp
2087 //===----------------------------------------------------------------------===//
2088 
2090  if (getVector().getType() != getType())
2091  return emitOpError("vector operand and result type mismatch");
2092  auto scalarType = llvm::cast<VectorType>(getType()).getElementType();
2093  if (getScalar().getType() != scalarType)
2094  return emitOpError("scalar operand and result element type match");
2095  return success();
2096 }
static std::string bindingName()
Returns the string name of the Binding decoration.
static std::string descriptorSetName()
Returns the string name of the DescriptorSet decoration.
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static uint64_t zext(uint32_t arg)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser, OperationState &result)
Definition: SPIRVOps.cpp:297
static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, Type opType)
Definition: SPIRVOps.cpp:583
static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, OperationState &result)
Definition: SPIRVOps.cpp:122
static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op)
Definition: SPIRVOps.cpp:283
static LogicalResult verifyShiftOp(Operation *op)
Definition: SPIRVOps.cpp:329
static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op, Value ptr, Value val)
Definition: SPIRVOps.cpp:199
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:216
static void printOneResultOp(Operation *op, OpAsmPrinter &p)
Definition: SPIRVOps.cpp:151
static LogicalResult verifyImageOperands(Op imageOp, spirv::ImageOperandsAttr attr, Operation::operand_range operands)
Definition: SPIRVOps.cpp:171
static void printArithmeticExtendedBinaryOp(Operation *op, OpAsmPrinter &printer)
Definition: SPIRVOps.cpp:321
static int64_t getNumElements(ShapedType type)
Definition: TensorOps.cpp:1352
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseOptionalSymbolName(StringAttr &result)=0
Parse an optional @-identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:68
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
ParseResult addTypesToList(ArrayRef< Type > types, SmallVectorImpl< Type > &result)
Add the specified types to the end of the specified type list and return success.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:30
OpListType & getOperations()
Definition: Block.h:130
Operation & front()
Definition: Block.h:146
iterator begin()
Definition: Block.h:136
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:216
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:238
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:283
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:261
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:96
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:87
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:116
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:269
MLIRContext * getContext() const
Definition: Builders.h:55
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:273
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition: Builders.h:100
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:41
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:308
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual Operation * parseGenericOperation(Block *insertBlock, Block::iterator insertPt)=0
Parse an operation in its generic form.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printGenericOp(Operation *op, bool printOpName=true)=0
Print the entire operation with the default generic assembly form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:333
This class helps build Operations.
Definition: Builders.h:206
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:421
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:419
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
This class represents a single result from folding an operation.
Definition: OpDefinition.h:266
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:775
A trait to mark ops that can be enclosed/wrapped in a SpecConstantOperation op.
Definition: SPIRVOpTraits.h:33
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
type_range getType() const
Definition: ValueRange.cpp:30
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:728
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:528
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:486
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:640
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
This class represents success/failure for parsing-like operations that find it important to chain tog...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
void push_back(Block *block)
Definition: Region.h:61
bool empty()
Definition: Region.h:60
Block & back()
Definition: Region.h:64
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:118
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:123
Type front()
Return first type in the range.
Definition: TypeRange.h:148
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:34
static WalkResult advance()
Definition: Visitors.h:52
ImageArrayedInfo getArrayedInfo() const
Definition: SPIRVTypes.cpp:491
ImageSamplerUseInfo getSamplerUseInfo() const
Definition: SPIRVTypes.cpp:499
ImageSamplingInfo getSamplingInfo() const
Definition: SPIRVTypes.cpp:495
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:546
SPIR-V struct type.
Definition: SPIRVTypes.h:284
unsigned getNumElements() const
Type getElementType(unsigned) const
An attribute that specifies the SPIR-V (version, capabilities, extensions) triple.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:137
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
ParseResult parseFunctionSignature(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:19
constexpr char kFnNameAttrName[]
constexpr char kInitializerAttrName[]
constexpr char kIndicesAttrName[]
constexpr char kSpecIdAttrName[]
constexpr char kValuesAttrName[]
constexpr char kValueAttrName[]
constexpr char kCompositeSpecConstituentsName[]
constexpr char kInterfaceAttrName[]
constexpr char kTypeAttrName[]
constexpr char kDefaultValueAttrName[]
LogicalResult verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics)
Definition: SPIRVOps.cpp:71
ParseResult parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next string attribute in parser as an enumerant of the given EnumClass.
void printVariableDecorations(Operation *op, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs)
Definition: SPIRVOps.cpp:95
AddressingModel getAddressingModel(TargetEnvAttr targetAttr, bool use64bitAddress)
Returns addressing model selected based on target environment.
FailureOr< ExecutionModel > getExecutionModel(TargetEnvAttr targetAttr)
Returns execution model selected based on target environment.
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
LogicalResult extractValueFromConstOp(Operation *op, int32_t &value)
Definition: SPIRVOps.cpp:51
ParseResult parseVariableDecorations(OpAsmParser &parser, OperationState &state)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:438
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.