MLIR  22.0.0git
SPIRVOps.cpp
Go to the documentation of this file.
1 //===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the operations in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
15 #include "SPIRVOpUtils.h"
16 #include "SPIRVParsingUtils.h"
17 
24 #include "mlir/IR/Builders.h"
25 #include "mlir/IR/BuiltinTypes.h"
26 #include "mlir/IR/OpDefinition.h"
28 #include "mlir/IR/Operation.h"
29 #include "mlir/IR/TypeUtilities.h"
31 #include "llvm/ADT/APFloat.h"
32 #include "llvm/ADT/APInt.h"
33 #include "llvm/ADT/ArrayRef.h"
34 #include "llvm/ADT/STLExtras.h"
35 #include "llvm/ADT/StringExtras.h"
36 #include "llvm/ADT/TypeSwitch.h"
37 #include "llvm/Support/InterleavedRange.h"
38 #include <cassert>
39 #include <numeric>
40 #include <optional>
41 
42 using namespace mlir;
43 using namespace mlir::spirv::AttrNames;
44 
45 //===----------------------------------------------------------------------===//
46 // Common utility functions
47 //===----------------------------------------------------------------------===//
48 
49 LogicalResult spirv::extractValueFromConstOp(Operation *op, int32_t &value) {
50  auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
51  if (!constOp) {
52  return failure();
53  }
54  auto valueAttr = constOp.getValue();
55  auto integerValueAttr = llvm::dyn_cast<IntegerAttr>(valueAttr);
56  if (!integerValueAttr) {
57  return failure();
58  }
59 
60  if (integerValueAttr.getType().isSignlessInteger())
61  value = integerValueAttr.getInt();
62  else
63  value = integerValueAttr.getSInt();
64 
65  return success();
66 }
67 
// Checks the memory-semantics mask used by `op`: at most one of the Acquire,
// Release, AcquireRelease and SequentiallyConsistent bits may be set (see the
// SPIR-V quote below); otherwise an error is emitted on `op`.
// NOTE(review): the line naming this function and its `op` parameter is
// missing from this extraction; the call in spirv.ControlBarrierOp::verify
// suggests it is `verifyMemorySemantics(Operation *op, ...)` — confirm.
68 LogicalResult
70  spirv::MemorySemantics memorySemantics) {
71  // According to the SPIR-V specification:
72  // "Despite being a mask and allowing multiple bits to be combined, it is
73  // invalid for more than one of these four bits to be set: Acquire, Release,
74  // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and
75  // Release semantics is done by setting the AcquireRelease bit, not by setting
76  // two bits."
77  auto atMostOneInSet = spirv::MemorySemantics::Acquire |
78  spirv::MemorySemantics::Release |
79  spirv::MemorySemantics::AcquireRelease |
80  spirv::MemorySemantics::SequentiallyConsistent;
81 
// Keep only the four ordering bits and count how many of them are set.
82  auto bitCount =
83  llvm::popcount(static_cast<uint32_t>(memorySemantics & atMostOneInSet));
84  if (bitCount > 1) {
// NOTE(review): the adjacent string literals concatenate to
// "`Release`,`AcquireRelease`" with no space after the comma; left unchanged
// because diagnostic tests may match the exact text.
85  return op->emitError(
86  "expected at most one of these four memory constraints "
87  "to be set: `Acquire`, `Release`,"
88  "`AcquireRelease` or `SequentiallyConsistent`");
89  }
90  return success();
91 }
92 
94  SmallVectorImpl<StringRef> &elidedAttrs) {
// Prints variable decorations of `op` in custom form:
//  - ` bind(<set>, <binding>)` when both the DescriptorSet and Binding
//    decorations are present as IntegerAttrs (attribute names are the
//    snake_case spelling of the decoration names);
//  - ` built_in("<value>")`-style text when the BuiltIn decoration is present
//    as a StringAttr;
// then prints every remaining attribute except those in `elidedAttrs`, which
// this function extends with the decorations it printed itself.
// NOTE(review): the first signature line (declaring `op` and `printer`) is
// missing from this extraction.
// NOTE(review): `descriptorSetName`, `bindingName` and `builtInName` are
// locals (their type is hidden by `auto`; presumably std::string). If so, the
// StringRefs pushed into the caller-owned `elidedAttrs` dangle after this
// function returns — callers should not read `elidedAttrs` afterwards.
95  // Print optional descriptor binding
96  auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
97  stringifyDecoration(spirv::Decoration::DescriptorSet));
98  auto bindingName = llvm::convertToSnakeFromCamelCase(
99  stringifyDecoration(spirv::Decoration::Binding));
100  auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
101  auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
// Only the combined `bind(...)` form is printed; if just one of the two is
// present it falls through to the generic attribute dictionary below.
102  if (descriptorSet && binding) {
103  elidedAttrs.push_back(descriptorSetName);
104  elidedAttrs.push_back(bindingName);
105  printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
106  << ")";
107  }
108 
109  // Print BuiltIn attribute if present
110  auto builtInName = llvm::convertToSnakeFromCamelCase(
111  stringifyDecoration(spirv::Decoration::BuiltIn));
112  if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
113  printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
114  elidedAttrs.push_back(builtInName);
115  }
116 
117  printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
118 }
119 
// Parses a single-result op whose operands share the result type. Two forms
// are accepted:
//   operands attr-dict `:` type              (compact form; one type for all)
//   `(` operands `)` attr-dict `:` fn-type   (generic-like fallback form)
// In the fallback form, operand types come from the function type's inputs
// and result types from its results.
// NOTE(review): the declaration of the operand list `ops` is on a line
// missing from this extraction.
120 static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
121  OperationState &result) {
123  Type type;
124  // If the operand list is in-between parentheses, then we have a generic form.
125  // (see the fallback in `printOneResultOp`).
126  SMLoc loc = parser.getCurrentLocation();
127  if (!parser.parseOptionalLParen()) {
128  if (parser.parseOperandList(ops) || parser.parseRParen() ||
129  parser.parseOptionalAttrDict(result.attributes) ||
130  parser.parseColon() || parser.parseType(type))
131  return failure();
132  auto fnType = llvm::dyn_cast<FunctionType>(type);
133  if (!fnType) {
134  parser.emitError(loc, "expected function type");
135  return failure();
136  }
137  if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands))
138  return failure();
139  result.addTypes(fnType.getResults());
140  return success();
141  }
// Compact form: the single trailing type is used for every operand and for
// the result.
142  return failure(parser.parseOperandList(ops) ||
143  parser.parseOptionalAttrDict(result.attributes) ||
144  parser.parseColonType(type) ||
145  parser.resolveOperands(ops, type, result.operands) ||
146  parser.addTypeToList(type, result.types));
147 }
148 
// Prints a single-result op `op` with printer `p`. When every operand type
// equals the result type, the compact `operands : type` form is used;
// otherwise the generic form (without the op name) is printed so no type
// information is lost.
// NOTE(review): the signature line (declaring `op` and `p`) and one line
// between printing the operands and the type are missing from this
// extraction. The missing body line presumably prints the attribute
// dictionary, which parseOneResultSameOperandTypeOp expects right after the
// operands — confirm.
150  assert(op->getNumResults() == 1 && "op should have one result");
151 
152  // If not all the operand and result types are the same, just use the
153  // generic assembly form to avoid omitting information in printing.
154  auto resultType = op->getResult(0).getType();
155  if (llvm::any_of(op->getOperandTypes(),
156  [&](Type type) { return type != resultType; })) {
157  p.printGenericOp(op, /*printOpName=*/false);
158  return;
159  }
160 
161  p << ' ';
162  p.printOperands(op->getOperands());
164  // Now we can output only one type for all operands and the result.
165  p << " : " << resultType;
166 }
167 
168 template <typename BlockReadWriteOpTy>
169 static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
170  Value ptr, Value val) {
171  auto valType = val.getType();
172  if (auto valVecTy = llvm::dyn_cast<VectorType>(valType))
173  valType = valVecTy.getElementType();
174 
175  if (valType !=
176  llvm::cast<spirv::PointerType>(ptr.getType()).getPointeeType()) {
177  return op.emitOpError("mismatch in result type and pointer type");
178  }
179  return success();
180 }
181 
182 /// Walks the given type hierarchy with the given indices, potentially down
183 /// to component granularity, to select an element type. Returns a null type
184 /// on failure after reporting the problem through `emitErrorFn`.
/// Fails when `indices` is empty, when an index is out of bounds for a
/// composite whose element count is known at compile time, or when an index
/// is applied to a non-composite type.
/// NOTE(review): the line declaring `type` and `indices` is missing from this
/// extraction; the Attribute overload passes a SmallVector<int32_t>, so
/// `indices` is presumably ArrayRef<int32_t>.
185 static Type
187  function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
// NOTE(review): this message names spirv.CompositeExtract even though
// spirv.CompositeInsert's verifier also reaches this helper.
188  if (indices.empty()) {
189  emitErrorFn("expected at least one index for spirv.CompositeExtract");
190  return nullptr;
191  }
192 
193  for (auto index : indices) {
194  if (auto cType = llvm::dyn_cast<spirv::CompositeType>(type)) {
// Bounds are only checked when the element count is known at compile time;
// otherwise the index (even a negative one) goes straight to getElementType.
195  if (cType.hasCompileTimeKnownNumElements() &&
196  (index < 0 ||
197  static_cast<uint64_t>(index) >= cType.getNumElements())) {
198  emitErrorFn("index ") << index << " out of bounds for " << type;
199  return nullptr;
200  }
201  type = cType.getElementType(index);
202  } else {
203  emitErrorFn("cannot extract from non-composite type ")
204  << type << " with index " << index;
205  return nullptr;
206  }
207  }
208  return type;
209 }
210 
// Attribute-based overload: `indices` must be an ArrayAttr of IntegerAttrs.
// Each index is read with IntegerAttr::getInt(), narrowed to int32_t, and the
// type walk is delegated to the integer-index overload. Errors are reported
// through `emitErrorFn`, and a null type is returned on failure.
// NOTE(review): the line declaring `type` and `indices` is missing from this
// extraction. The `llvm::dyn_cast` below presumes `indices` is non-null.
211 static Type
213  function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
214  auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(indices);
215  if (!indicesArrayAttr) {
216  emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
217  return nullptr;
218  }
219  if (indicesArrayAttr.empty()) {
220  emitErrorFn("expected at least one index for spirv.CompositeExtract");
221  return nullptr;
222  }
223 
224  SmallVector<int32_t, 2> indexVals;
225  for (auto indexAttr : indicesArrayAttr) {
226  auto indexIntAttr = llvm::dyn_cast<IntegerAttr>(indexAttr);
227  if (!indexIntAttr) {
228  emitErrorFn("expected an 32-bit integer for index, but found '")
229  << indexAttr << "'";
230  return nullptr;
231  }
// getInt() yields int64_t; push_back narrows it to int32_t.
232  indexVals.push_back(indexIntAttr.getInt());
233  }
234  return getElementType(type, indexVals, emitErrorFn);
235 }
236 
237 static Type getElementType(Type type, Attribute indices, Location loc) {
238  auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
239  return ::mlir::emitError(loc, err);
240  };
241  return getElementType(type, indices, errorFn);
242 }
243 
244 static Type getElementType(Type type, Attribute indices, OpAsmParser &parser,
245  SMLoc loc) {
246  auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
247  return parser.emitError(loc, err);
248  };
249  return getElementType(type, indices, errorFn);
250 }
251 
252 template <typename ExtendedBinaryOp>
253 static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op) {
254  auto resultType = llvm::cast<spirv::StructType>(op.getType());
255  if (resultType.getNumElements() != 2)
256  return op.emitOpError("expected result struct type containing two members");
257 
258  if (!llvm::all_equal({op.getOperand1().getType(), op.getOperand2().getType(),
259  resultType.getElementType(0),
260  resultType.getElementType(1)}))
261  return op.emitOpError(
262  "expected all operand types and struct member types are the same");
263 
264  return success();
265 }
266 
// Parses `attr-dict operands : !spirv.struct<(T, T)>`. The result type must
// be a two-member struct; both operands are resolved with the struct's first
// member type.
// NOTE(review): the declaration of the operand list `operands` is on a line
// missing from this extraction.
267 static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser,
268  OperationState &result) {
270  if (parser.parseOptionalAttrDict(result.attributes) ||
271  parser.parseOperandList(operands) || parser.parseColon())
272  return failure();
273 
274  Type resultType;
275  SMLoc loc = parser.getCurrentLocation();
276  if (parser.parseType(resultType))
277  return failure();
278 
279  auto structType = llvm::dyn_cast<spirv::StructType>(resultType);
280  if (!structType || structType.getNumElements() != 2)
281  return parser.emitError(loc, "expected spirv.struct type with two members");
282 
// Operand types are not spelled out; both are taken from member 0.
283  SmallVector<Type, 2> operandTypes(2, structType.getElementType(0));
284  if (parser.resolveOperands(operands, operandTypes, loc, result.operands))
285  return failure();
286 
287  result.addTypes(resultType);
288  return success();
289 }
290 
292  OpAsmPrinter &printer) {
// Prints the attribute dictionary, then the operands, then ` : ` and the
// first result type, mirroring parseArithmeticExtendedBinaryOp's order.
// NOTE(review): the first signature line (declaring `op`) is missing from
// this extraction.
293  printer << ' ';
294  printer.printOptionalAttrDict(op->getAttrs());
295  printer.printOperands(op->getOperands());
296  printer << " : " << op->getResultTypes().front();
297 }
298 
299 static LogicalResult verifyShiftOp(Operation *op) {
300  if (op->getOperand(0).getType() != op->getResult(0).getType()) {
301  return op->emitError("expected the same type for the first operand and "
302  "result, but provided ")
303  << op->getOperand(0).getType() << " and "
304  << op->getResult(0).getType();
305  }
306  return success();
307 }
308 
309 //===----------------------------------------------------------------------===//
310 // spirv.mlir.addressof
311 //===----------------------------------------------------------------------===//
312 
313 void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state,
314  spirv::GlobalVariableOp var) {
315  build(builder, state, var.getType(), SymbolRefAttr::get(var));
316 }
317 
318 LogicalResult spirv::AddressOfOp::verify() {
319  auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
320  SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(),
321  getVariableAttr()));
322  if (!varOp) {
323  return emitOpError("expected spirv.GlobalVariable symbol");
324  }
325  if (getPointer().getType() != varOp.getType()) {
326  return emitOpError(
327  "result type mismatch with the referenced global variable's type");
328  }
329  return success();
330 }
331 
332 //===----------------------------------------------------------------------===//
333 // spirv.CompositeConstruct
334 //===----------------------------------------------------------------------===//
335 
// Verifies that the constituents of a spirv.CompositeConstruct match its
// result type under one of the four layouts listed below.
336 LogicalResult spirv::CompositeConstructOp::verify() {
337  operand_range constituents = this->getConstituents();
338 
339  // There are 4 cases with varying verification rules:
340  // 1. Cooperative Matrices (1 constituent)
341  // 2. Structs (1 constituent for each member)
342  // 3. Arrays (1 constituent for each array element)
343  // 4. Vectors (1 constituent (sub-)element for each vector element)
344 
// Element type of the result when it is a cooperative matrix, else null.
// NOTE(review): two lines of this expression are missing from this
// extraction; they presumably open an llvm::TypeSwitch over the result type
// with a cooperative-matrix case — confirm.
345  auto coopElementType =
348  [](auto coopType) { return coopType.getElementType(); })
349  .Default([](Type) { return nullptr; });
350 
351  // Case 1. -- matrices.
352  if (coopElementType) {
353  if (constituents.size() != 1)
354  return emitOpError("has incorrect number of operands: expected ")
355  << "1, but provided " << constituents.size();
356  if (coopElementType != constituents.front().getType())
357  return emitOpError("operand type mismatch: expected operand type ")
358  << coopElementType << ", but provided "
359  << constituents.front().getType();
360  return success();
361  }
362 
363  // Case 2./3./4. -- number of constituents matches the number of elements.
364  auto cType = llvm::cast<spirv::CompositeType>(getType());
365  if (constituents.size() == cType.getNumElements()) {
366  for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
367  if (constituents[index].getType() != cType.getElementType(index)) {
368  return emitOpError("operand type mismatch: expected operand type ")
369  << cType.getElementType(index) << ", but provided "
370  << constituents[index].getType();
371  }
372  }
373  return success();
374  }
375 
376  // Case 4. -- check that all constituents add up to the expected vector type.
377  auto resultType = llvm::dyn_cast<VectorType>(cType);
378  if (!resultType)
379  return emitOpError(
380  "expected to return a vector or cooperative matrix when the number of "
381  "constituents is less than what the result needs");
382 
// Each constituent is a scalar or vector with the result's element type;
// record how many result elements it contributes.
383  SmallVector<unsigned> sizes;
384  for (Value component : constituents) {
385  if (!llvm::isa<VectorType>(component.getType()) &&
386  !component.getType().isIntOrFloat())
387  return emitOpError("operand type mismatch: expected operand to have "
388  "a scalar or vector type, but provided ")
389  << component.getType();
390 
391  Type elementType = component.getType();
392  if (auto vectorType = llvm::dyn_cast<VectorType>(component.getType())) {
393  sizes.push_back(vectorType.getNumElements());
394  elementType = vectorType.getElementType();
395  } else {
396  sizes.push_back(1);
397  }
398 
399  if (elementType != resultType.getElementType())
400  return emitOpError("operand element type mismatch: expected to be ")
401  << resultType.getElementType() << ", but provided " << elementType;
402  }
// The initial value `0` makes std::accumulate sum in `int` before the result
// is stored as unsigned.
403  unsigned totalCount = std::accumulate(sizes.begin(), sizes.end(), 0);
404  if (totalCount != cType.getNumElements())
405  return emitOpError("has incorrect number of operands: expected ")
406  << cType.getNumElements() << ", but provided " << totalCount;
407  return success();
408 }
409 
410 //===----------------------------------------------------------------------===//
411 // spirv.CompositeExtractOp
412 //===----------------------------------------------------------------------===//
413 
414 void spirv::CompositeExtractOp::build(OpBuilder &builder, OperationState &state,
415  Value composite,
416  ArrayRef<int32_t> indices) {
417  auto indexAttr = builder.getI32ArrayAttr(indices);
418  auto elementType =
419  getElementType(composite.getType(), indexAttr, state.location);
420  if (!elementType) {
421  return;
422  }
423  build(builder, state, elementType, composite, indexAttr);
424 }
425 
427  OperationState &result) {
// Parses `%composite [indices] : composite-type`. The result type is not
// written; it is derived from the composite type and the indices, and
// diagnostics point at the indices attribute.
// NOTE(review): the first signature line is missing from this extraction.
428  OpAsmParser::UnresolvedOperand compositeInfo;
429  Attribute indicesAttr;
430  StringRef indicesAttrName =
431  spirv::CompositeExtractOp::getIndicesAttrName(result.name);
432  Type compositeType;
433  SMLoc attrLocation;
434 
435  if (parser.parseOperand(compositeInfo) ||
436  parser.getCurrentLocation(&attrLocation) ||
437  parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
438  parser.parseColonType(compositeType) ||
439  parser.resolveOperand(compositeInfo, compositeType, result.operands)) {
440  return failure();
441  }
442 
443  Type resultType =
444  getElementType(compositeType, indicesAttr, parser, attrLocation);
445  if (!resultType) {
446  return failure();
447  }
448  result.addTypes(resultType);
449  return success();
450 }
451 
// Prints ` %composite[indices] : composite-type`; the result type is omitted
// because the parser re-derives it.
// NOTE(review): the signature line is missing from this extraction.
453  printer << ' ' << getComposite() << getIndices() << " : "
454  << getComposite().getType();
455 }
456 
457 LogicalResult spirv::CompositeExtractOp::verify() {
458  auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(getIndices());
459  auto resultType =
460  getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
461  if (!resultType)
462  return failure();
463 
464  if (resultType != getType()) {
465  return emitOpError("invalid result type: expected ")
466  << resultType << " but provided " << getType();
467  }
468 
469  return success();
470 }
471 
472 //===----------------------------------------------------------------------===//
473 // spirv.CompositeInsert
474 //===----------------------------------------------------------------------===//
475 
476 void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state,
477  Value object, Value composite,
478  ArrayRef<int32_t> indices) {
479  auto indexAttr = builder.getI32ArrayAttr(indices);
480  build(builder, state, composite.getType(), object, composite, indexAttr);
481 }
482 
// Parses `%object, %composite [indices] : object-type into composite-type`.
// Exactly two operands are required; the result takes the composite type.
// NOTE(review): the declaration of the operand list `operands` is on a line
// missing from this extraction.
483 ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
484  OperationState &result) {
486  Type objectType, compositeType;
487  Attribute indicesAttr;
488  StringRef indicesAttrName =
489  spirv::CompositeInsertOp::getIndicesAttrName(result.name);
490  auto loc = parser.getCurrentLocation();
491 
492  return failure(
493  parser.parseOperandList(operands, 2) ||
494  parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
495  parser.parseColonType(objectType) ||
496  parser.parseKeywordType("into", compositeType) ||
497  parser.resolveOperands(operands, {objectType, compositeType}, loc,
498  result.operands) ||
499  parser.addTypesToList(compositeType, result.types));
500 }
501 
502 LogicalResult spirv::CompositeInsertOp::verify() {
503  auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(getIndices());
504  auto objectType =
505  getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
506  if (!objectType)
507  return failure();
508 
509  if (objectType != getObject().getType()) {
510  return emitOpError("object operand type should be ")
511  << objectType << ", but found " << getObject().getType();
512  }
513 
514  if (getComposite().getType() != getType()) {
515  return emitOpError("result type should be the same as "
516  "the composite type, but found ")
517  << getComposite().getType() << " vs " << getType();
518  }
519 
520  return success();
521 }
522 
// Prints ` %object, %composite[indices] : object-type into composite-type`,
// the form accepted by CompositeInsertOp::parse.
// NOTE(review): the signature line is missing from this extraction.
524  printer << " " << getObject() << ", " << getComposite() << getIndices()
525  << " : " << getObject().getType() << " into "
526  << getComposite().getType();
527 }
528 
529 //===----------------------------------------------------------------------===//
530 // spirv.Constant
531 //===----------------------------------------------------------------------===//
532 
533 ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
534  OperationState &result) {
535  Attribute value;
536  StringRef valueAttrName = spirv::ConstantOp::getValueAttrName(result.name);
537  if (parser.parseAttribute(value, valueAttrName, result.attributes))
538  return failure();
539 
540  Type type = NoneType::get(parser.getContext());
541  if (auto typedAttr = llvm::dyn_cast<TypedAttr>(value))
542  type = typedAttr.getType();
543  if (llvm::isa<NoneType, TensorType>(type)) {
544  if (parser.parseColonType(type))
545  return failure();
546  }
547 
548  if (llvm::isa<TensorArmType>(type)) {
549  if (parser.parseOptionalColon().succeeded())
550  if (parser.parseType(type))
551  return failure();
552  }
553 
554  return parser.addTypeToList(type, result.types);
555 }
556 
557 void spirv::ConstantOp::print(OpAsmPrinter &printer) {
558  printer << ' ' << getValue();
559  if (llvm::isa<spirv::ArrayType>(getType()))
560  printer << " : " << getType();
561 }
562 
563 static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
564  Type opType) {
565  if (isa<spirv::CooperativeMatrixType>(opType)) {
566  auto denseAttr = dyn_cast<DenseElementsAttr>(value);
567  if (!denseAttr || !denseAttr.isSplat())
568  return op.emitOpError("expected a splat dense attribute for cooperative "
569  "matrix constant, but found ")
570  << denseAttr;
571  }
572  if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
573  auto valueType = llvm::cast<TypedAttr>(value).getType();
574  if (valueType != opType)
575  return op.emitOpError("result type (")
576  << opType << ") does not match value type (" << valueType << ")";
577  return success();
578  }
579  if (llvm::isa<DenseIntOrFPElementsAttr, SparseElementsAttr>(value)) {
580  auto valueType = llvm::cast<TypedAttr>(value).getType();
581  if (valueType == opType)
582  return success();
583  auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
584  auto shapedType = llvm::dyn_cast<ShapedType>(valueType);
585  if (!arrayType)
586  return op.emitOpError("result or element type (")
587  << opType << ") does not match value type (" << valueType
588  << "), must be the same or spirv.array";
589 
590  int numElements = arrayType.getNumElements();
591  auto opElemType = arrayType.getElementType();
592  while (auto t = llvm::dyn_cast<spirv::ArrayType>(opElemType)) {
593  numElements *= t.getNumElements();
594  opElemType = t.getElementType();
595  }
596  if (!opElemType.isIntOrFloat())
597  return op.emitOpError("only support nested array result type");
598 
599  auto valueElemType = shapedType.getElementType();
600  if (valueElemType != opElemType) {
601  return op.emitOpError("result element type (")
602  << opElemType << ") does not match value element type ("
603  << valueElemType << ")";
604  }
605 
606  if (numElements != shapedType.getNumElements()) {
607  return op.emitOpError("result number of elements (")
608  << numElements << ") does not match value number of elements ("
609  << shapedType.getNumElements() << ")";
610  }
611  return success();
612  }
613  if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(value)) {
614  auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
615  if (!arrayType)
616  return op.emitOpError(
617  "must have spirv.array result type for array value");
618  Type elemType = arrayType.getElementType();
619  for (Attribute element : arrayAttr.getValue()) {
620  // Verify array elements recursively.
621  if (failed(verifyConstantType(op, element, elemType)))
622  return failure();
623  }
624  return success();
625  }
626  return op.emitOpError("cannot have attribute: ") << value;
627 }
628 
629 LogicalResult spirv::ConstantOp::verify() {
630  // ODS already generates checks to make sure the result type is valid. We just
631  // need to additionally check that the value's attribute type is consistent
632  // with the result type.
633  return verifyConstantType(*this, getValueAttr(), getType());
634 }
635 
636 bool spirv::ConstantOp::isBuildableWith(Type type) {
637  // Must be valid SPIR-V type first.
638  if (!llvm::isa<spirv::SPIRVType>(type))
639  return false;
640 
641  if (isa<SPIRVDialect>(type.getDialect())) {
642  // TODO: support constant struct
643  return llvm::isa<spirv::ArrayType>(type);
644  }
645 
646  return true;
647 }
648 
649 spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
650  OpBuilder &builder) {
651  if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
652  unsigned width = intType.getWidth();
653  if (width == 1)
654  return spirv::ConstantOp::create(builder, loc, type,
655  builder.getBoolAttr(false));
656  return spirv::ConstantOp::create(
657  builder, loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
658  }
659  if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
660  return spirv::ConstantOp::create(builder, loc, type,
661  builder.getFloatAttr(floatType, 0.0));
662  }
663  if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
664  Type elemType = vectorType.getElementType();
665  if (llvm::isa<IntegerType>(elemType)) {
666  return spirv::ConstantOp::create(
667  builder, loc, type,
668  DenseElementsAttr::get(vectorType,
669  IntegerAttr::get(elemType, 0).getValue()));
670  }
671  if (llvm::isa<FloatType>(elemType)) {
672  return spirv::ConstantOp::create(
673  builder, loc, type,
674  DenseFPElementsAttr::get(vectorType,
675  FloatAttr::get(elemType, 0.0).getValue()));
676  }
677  }
678 
679  llvm_unreachable("unimplemented types for ConstantOp::getZero()");
680 }
681 
682 spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
683  OpBuilder &builder) {
684  if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
685  unsigned width = intType.getWidth();
686  if (width == 1)
687  return spirv::ConstantOp::create(builder, loc, type,
688  builder.getBoolAttr(true));
689  return spirv::ConstantOp::create(
690  builder, loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
691  }
692  if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
693  return spirv::ConstantOp::create(builder, loc, type,
694  builder.getFloatAttr(floatType, 1.0));
695  }
696  if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
697  Type elemType = vectorType.getElementType();
698  if (llvm::isa<IntegerType>(elemType)) {
699  return spirv::ConstantOp::create(
700  builder, loc, type,
701  DenseElementsAttr::get(vectorType,
702  IntegerAttr::get(elemType, 1).getValue()));
703  }
704  if (llvm::isa<FloatType>(elemType)) {
705  return spirv::ConstantOp::create(
706  builder, loc, type,
707  DenseFPElementsAttr::get(vectorType,
708  FloatAttr::get(elemType, 1.0).getValue()));
709  }
710  }
711 
712  llvm_unreachable("unimplemented types for ConstantOp::getOne()");
713 }
714 
715 void mlir::spirv::ConstantOp::getAsmResultNames(
716  llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
717  Type type = getType();
718 
719  SmallString<32> specialNameBuffer;
720  llvm::raw_svector_ostream specialName(specialNameBuffer);
721  specialName << "cst";
722 
723  IntegerType intTy = llvm::dyn_cast<IntegerType>(type);
724 
725  if (IntegerAttr intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
726  assert(intTy);
727 
728  if (intTy.getWidth() == 1) {
729  return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
730  }
731 
732  if (intTy.isSignless()) {
733  specialName << intCst.getInt();
734  } else if (intTy.isUnsigned()) {
735  specialName << intCst.getUInt();
736  } else {
737  specialName << intCst.getSInt();
738  }
739  }
740 
741  if (intTy || llvm::isa<FloatType>(type)) {
742  specialName << '_' << type;
743  }
744 
745  if (auto vecType = llvm::dyn_cast<VectorType>(type)) {
746  specialName << "_vec_";
747  specialName << vecType.getDimSize(0);
748 
749  Type elementType = vecType.getElementType();
750 
751  if (llvm::isa<IntegerType>(elementType) ||
752  llvm::isa<FloatType>(elementType)) {
753  specialName << "x" << elementType;
754  }
755  }
756 
757  setNameFn(getResult(), specialName.str());
758 }
759 
760 void mlir::spirv::AddressOfOp::getAsmResultNames(
761  llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
762  SmallString<32> specialNameBuffer;
763  llvm::raw_svector_ostream specialName(specialNameBuffer);
764  specialName << getVariable() << "_addr";
765  setNameFn(getResult(), specialName.str());
766 }
767 
768 //===----------------------------------------------------------------------===//
769 // spirv.EXTConstantCompositeReplicate
770 //===----------------------------------------------------------------------===//
771 
772 // Returns the type of `attr`. For a TypedAttr this is simply its type. For an
773 // ArrayAttr, which is untyped and may be nested (multidimensional), a
774 // spirv.array type is built recursively from its first element. Any other
// attribute kind yields a null type.
// NOTE(review): the signature line is missing from this extraction.
776  if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
777  return typedAttr.getType();
778  }
779 
// NOTE(review): `arrayAttr[0]` assumes the ArrayAttr is non-empty, and a null
// element type (first element neither typed nor an array) is passed straight
// to spirv::ArrayType::get. Confirm callers guarantee both.
780  if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
781  return spirv::ArrayType::get(getValueType(arrayAttr[0]), arrayAttr.size());
782  }
783 
784  return nullptr;
785 }
786 
// Verifies the replicated value against the composite result type: the
// value's type must equal one of the types reached by repeatedly taking
// element 0 of the result type, starting one level below the result itself.
// NOTE(review): the signature line is missing from this extraction.
788  Type valueType = getValueType(getValue());
789  if (!valueType)
790  return emitError("unknown value attribute type");
791 
792  auto compositeType = dyn_cast<spirv::CompositeType>(getType());
793  if (!compositeType)
794  return emitError("result type is not a composite type");
795 
796  Type compositeElementType = compositeType.getElementType(0);
797 
// Collect the element type at each nesting level of the composite.
798  SmallVector<Type, 3> possibleTypes = {compositeElementType};
799  while (auto type = dyn_cast<spirv::CompositeType>(compositeElementType)) {
800  compositeElementType = type.getElementType(0);
801  possibleTypes.push_back(compositeElementType);
802  }
803 
804  if (!is_contained(possibleTypes, valueType)) {
805  return emitError("expected value attribute type ")
806  << interleaved(possibleTypes, " or ") << ", but got: " << valueType;
807  }
808 
809  return success();
810 }
811 
812 //===----------------------------------------------------------------------===//
813 // spirv.ControlBarrierOp
814 //===----------------------------------------------------------------------===//
815 
816 LogicalResult spirv::ControlBarrierOp::verify() {
817  return verifyMemorySemantics(getOperation(), getMemorySemantics());
818 }
819 
820 //===----------------------------------------------------------------------===//
821 // spirv.EntryPoint
822 //===----------------------------------------------------------------------===//
823 
824 void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state,
825  spirv::ExecutionModel executionModel,
826  spirv::FuncOp function,
827  ArrayRef<Attribute> interfaceVars) {
828  build(builder, state,
829  spirv::ExecutionModelAttr::get(builder.getContext(), executionModel),
830  SymbolRefAttr::get(function), builder.getArrayAttr(interfaceVars));
831 }
832 
// Parses `"<ExecutionModel>" @fn (, @interface_var)*`. The interface
// variables are collected into an ArrayAttr stored under the `interface`
// attribute name; it is added even when the list is empty.
// NOTE(review): the declaration of `fn` is on a line missing from this
// extraction (presumably a FlatSymbolRefAttr) — confirm.
833 ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
834  OperationState &result) {
835  spirv::ExecutionModel execModel;
836  SmallVector<Attribute, 4> interfaceVars;
837 
839  if (parseEnumStrAttr<spirv::ExecutionModelAttr>(execModel, parser, result) ||
840  parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) {
841  return failure();
842  }
843 
844  if (!parser.parseOptionalComma()) {
845  // Parse the interface variables
846  if (parser.parseCommaSeparatedList([&]() -> ParseResult {
847  // The name of the interface variable attribute isn't important; each
// symbol is parsed into a throwaway NamedAttrList and kept in interfaceVars.
848  FlatSymbolRefAttr var;
849  NamedAttrList attrs;
850  if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
851  return failure();
852  interfaceVars.push_back(var);
853  return success();
854  }))
855  return failure();
856  }
857  result.addAttribute(spirv::EntryPointOp::getInterfaceAttrName(result.name),
858  parser.getBuilder().getArrayAttr(interfaceVars));
859  return success();
860 }
861 
// Prints ` "<ExecutionModel>" @fn`, followed by `, ` and the comma-separated
// interface variables when there are any.
// NOTE(review): the signature line is missing from this extraction.
863  printer << " \"" << stringifyExecutionModel(getExecutionModel()) << "\" ";
864  printer.printSymbolName(getFn());
865  auto interfaceVars = getInterface().getValue();
866  if (!interfaceVars.empty())
867  printer << ", " << llvm::interleaved(interfaceVars);
868 }
869 
870 LogicalResult spirv::EntryPointOp::verify() {
871  // Checks for fn and interface symbol reference are done in spirv::ModuleOp
872  // verification.
873  return success();
874 }
875 
876 //===----------------------------------------------------------------------===//
877 // spirv.ExecutionMode
878 //===----------------------------------------------------------------------===//
879 
880 void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state,
881  spirv::FuncOp function,
882  spirv::ExecutionMode executionMode,
883  ArrayRef<int32_t> params) {
884  build(builder, state, SymbolRefAttr::get(function),
885  spirv::ExecutionModeAttr::get(builder.getContext(), executionMode),
886  builder.getI32ArrayAttr(params));
887 }
888 
889 ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
890  OperationState &result) {
891  spirv::ExecutionMode execMode;
892  Attribute fn;
893  if (parser.parseAttribute(fn, kFnNameAttrName, result.attributes) ||
894  parseEnumStrAttr<spirv::ExecutionModeAttr>(execMode, parser, result)) {
895  return failure();
896  }
897 
899  Type i32Type = parser.getBuilder().getIntegerType(32);
900  while (!parser.parseOptionalComma()) {
901  NamedAttrList attr;
902  Attribute value;
903  if (parser.parseAttribute(value, i32Type, "value", attr)) {
904  return failure();
905  }
906  values.push_back(llvm::cast<IntegerAttr>(value).getInt());
907  }
908  StringRef valuesAttrName =
909  spirv::ExecutionModeOp::getValuesAttrName(result.name);
910  result.addAttribute(valuesAttrName,
911  parser.getBuilder().getI32ArrayAttr(values));
912  return success();
913 }
914 
916  printer << " ";
917  printer.printSymbolName(getFn());
918  printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\"";
919  ArrayAttr values = this->getValues();
920  if (!values.empty())
921  printer << ", " << llvm::interleaved(values.getAsValueRange<IntegerAttr>());
922 }
923 
924 //===----------------------------------------------------------------------===//
925 // spirv.func
926 //===----------------------------------------------------------------------===//
927 
928 ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
930  SmallVector<DictionaryAttr> resultAttrs;
931  SmallVector<Type> resultTypes;
932  auto &builder = parser.getBuilder();
933 
934  // Parse the name as a symbol.
935  StringAttr nameAttr;
936  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
937  result.attributes))
938  return failure();
939 
940  // Parse the function signature.
941  bool isVariadic = false;
943  parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
944  resultAttrs))
945  return failure();
946 
947  SmallVector<Type> argTypes;
948  for (auto &arg : entryArgs)
949  argTypes.push_back(arg.type);
950  auto fnType = builder.getFunctionType(argTypes, resultTypes);
951  result.addAttribute(getFunctionTypeAttrName(result.name),
952  TypeAttr::get(fnType));
953 
954  // Parse the optional function control keyword.
955  spirv::FunctionControl fnControl;
956  if (parseEnumStrAttr<spirv::FunctionControlAttr>(fnControl, parser, result))
957  return failure();
958 
959  // If additional attributes are present, parse them.
960  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
961  return failure();
962 
963  // Add the attributes to the function arguments.
964  assert(resultAttrs.size() == resultTypes.size());
966  builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
967  getResAttrsAttrName(result.name));
968 
969  // Parse the optional function body.
970  auto *body = result.addRegion();
971  OptionalParseResult parseResult =
972  parser.parseOptionalRegion(*body, entryArgs);
973  return failure(parseResult.has_value() && failed(*parseResult));
974 }
975 
976 void spirv::FuncOp::print(OpAsmPrinter &printer) {
977  // Print function name, signature, and control.
978  printer << " ";
979  printer.printSymbolName(getSymName());
980  auto fnType = getFunctionType();
982  printer, *this, fnType.getInputs(),
983  /*isVariadic=*/false, fnType.getResults());
984  printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl())
985  << "\"";
987  printer, *this,
988  {spirv::attributeName<spirv::FunctionControl>(),
989  getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
990  getFunctionControlAttrName()});
991 
992  // Print the body if this is not an external function.
993  Region &body = this->getBody();
994  if (!body.empty()) {
995  printer << ' ';
996  printer.printRegion(body, /*printEntryBlockArgs=*/false,
997  /*printBlockTerminators=*/true);
998  }
999 }
1000 
1001 LogicalResult spirv::FuncOp::verifyType() {
1002  FunctionType fnType = getFunctionType();
1003  if (fnType.getNumResults() > 1)
1004  return emitOpError("cannot have more than one result");
1005 
1006  auto hasDecorationAttr = [&](spirv::Decoration decoration,
1007  unsigned argIndex) {
1008  auto func = llvm::cast<FunctionOpInterface>(getOperation());
1009  for (auto argAttr : cast<FunctionOpInterface>(func).getArgAttrs(argIndex)) {
1010  if (argAttr.getName() != spirv::DecorationAttr::name)
1011  continue;
1012  if (auto decAttr = dyn_cast<spirv::DecorationAttr>(argAttr.getValue()))
1013  return decAttr.getValue() == decoration;
1014  }
1015  return false;
1016  };
1017 
1018  for (unsigned i = 0, e = this->getNumArguments(); i != e; ++i) {
1019  Type param = fnType.getInputs()[i];
1020  auto inputPtrType = dyn_cast<spirv::PointerType>(param);
1021  if (!inputPtrType)
1022  continue;
1023 
1024  auto pointeePtrType =
1025  dyn_cast<spirv::PointerType>(inputPtrType.getPointeeType());
1026  if (pointeePtrType) {
1027  // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
1028  // > If an OpFunctionParameter is a pointer (or contains a pointer)
1029  // > and the type it points to is a pointer in the PhysicalStorageBuffer
1030  // > storage class, the function parameter must be decorated with exactly
1031  // > one of AliasedPointer or RestrictPointer.
1032  if (pointeePtrType.getStorageClass() !=
1033  spirv::StorageClass::PhysicalStorageBuffer)
1034  continue;
1035 
1036  bool hasAliasedPtr =
1037  hasDecorationAttr(spirv::Decoration::AliasedPointer, i);
1038  bool hasRestrictPtr =
1039  hasDecorationAttr(spirv::Decoration::RestrictPointer, i);
1040  if (!hasAliasedPtr && !hasRestrictPtr)
1041  return emitOpError()
1042  << "with a pointer points to a physical buffer pointer must "
1043  "be decorated either 'AliasedPointer' or 'RestrictPointer'";
1044  continue;
1045  }
1046  // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
1047  // > If an OpFunctionParameter is a pointer (or contains a pointer) in
1048  // > the PhysicalStorageBuffer storage class, the function parameter must
1049  // > be decorated with exactly one of Aliased or Restrict.
1050  if (auto pointeeArrayType =
1051  dyn_cast<spirv::ArrayType>(inputPtrType.getPointeeType())) {
1052  pointeePtrType =
1053  dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
1054  } else {
1055  pointeePtrType = inputPtrType;
1056  }
1057 
1058  if (!pointeePtrType || pointeePtrType.getStorageClass() !=
1059  spirv::StorageClass::PhysicalStorageBuffer)
1060  continue;
1061 
1062  bool hasAliased = hasDecorationAttr(spirv::Decoration::Aliased, i);
1063  bool hasRestrict = hasDecorationAttr(spirv::Decoration::Restrict, i);
1064  if (!hasAliased && !hasRestrict)
1065  return emitOpError() << "with physical buffer pointer must be decorated "
1066  "either 'Aliased' or 'Restrict'";
1067  }
1068 
1069  return success();
1070 }
1071 
1072 LogicalResult spirv::FuncOp::verifyBody() {
1073  FunctionType fnType = getFunctionType();
1074  if (!isExternal()) {
1075  Block &entryBlock = front();
1076 
1077  unsigned numArguments = this->getNumArguments();
1078  if (entryBlock.getNumArguments() != numArguments)
1079  return emitOpError("entry block must have ")
1080  << numArguments << " arguments to match function signature";
1081 
1082  for (auto [index, fnArgType, blockArgType] :
1083  llvm::enumerate(getArgumentTypes(), entryBlock.getArgumentTypes())) {
1084  if (blockArgType != fnArgType) {
1085  return emitOpError("type of entry block argument #")
1086  << index << '(' << blockArgType
1087  << ") must match the type of the corresponding argument in "
1088  << "function signature(" << fnArgType << ')';
1089  }
1090  }
1091  }
1092 
1093  auto walkResult = walk([fnType](Operation *op) -> WalkResult {
1094  if (auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
1095  if (fnType.getNumResults() != 0)
1096  return retOp.emitOpError("cannot be used in functions returning value");
1097  } else if (auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
1098  if (fnType.getNumResults() != 1)
1099  return retOp.emitOpError(
1100  "returns 1 value but enclosing function requires ")
1101  << fnType.getNumResults() << " results";
1102 
1103  auto retOperandType = retOp.getValue().getType();
1104  auto fnResultType = fnType.getResult(0);
1105  if (retOperandType != fnResultType)
1106  return retOp.emitOpError(" return value's type (")
1107  << retOperandType << ") mismatch with function's result type ("
1108  << fnResultType << ")";
1109  }
1110  return WalkResult::advance();
1111  });
1112 
1113  // TODO: verify other bits like linkage type.
1114 
1115  return failure(walkResult.wasInterrupted());
1116 }
1117 
1118 void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
1119  StringRef name, FunctionType type,
1120  spirv::FunctionControl control,
1121  ArrayRef<NamedAttribute> attrs) {
1122  state.addAttribute(SymbolTable::getSymbolAttrName(),
1123  builder.getStringAttr(name));
1124  state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
1125  state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
1126  builder.getAttr<spirv::FunctionControlAttr>(control));
1127  state.attributes.append(attrs.begin(), attrs.end());
1128  state.addRegion();
1129 }
1130 
1131 //===----------------------------------------------------------------------===//
1132 // spirv.GLFClampOp
1133 //===----------------------------------------------------------------------===//
1134 
1135 ParseResult spirv::GLFClampOp::parse(OpAsmParser &parser,
1136  OperationState &result) {
1137  return parseOneResultSameOperandTypeOp(parser, result);
1138 }
1140 
1141 //===----------------------------------------------------------------------===//
1142 // spirv.GLUClampOp
1143 //===----------------------------------------------------------------------===//
1144 
1145 ParseResult spirv::GLUClampOp::parse(OpAsmParser &parser,
1146  OperationState &result) {
1147  return parseOneResultSameOperandTypeOp(parser, result);
1148 }
1150 
1151 //===----------------------------------------------------------------------===//
1152 // spirv.GLSClampOp
1153 //===----------------------------------------------------------------------===//
1154 
1155 ParseResult spirv::GLSClampOp::parse(OpAsmParser &parser,
1156  OperationState &result) {
1157  return parseOneResultSameOperandTypeOp(parser, result);
1158 }
1160 
1161 //===----------------------------------------------------------------------===//
1162 // spirv.GLFmaOp
1163 //===----------------------------------------------------------------------===//
1164 
1165 ParseResult spirv::GLFmaOp::parse(OpAsmParser &parser, OperationState &result) {
1166  return parseOneResultSameOperandTypeOp(parser, result);
1167 }
1168 void spirv::GLFmaOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1169 
1170 //===----------------------------------------------------------------------===//
1171 // spirv.GlobalVariable
1172 //===----------------------------------------------------------------------===//
1173 
1174 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
1175  Type type, StringRef name,
1176  unsigned descriptorSet, unsigned binding) {
1177  build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
1178  state.addAttribute(
1179  spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
1180  builder.getI32IntegerAttr(descriptorSet));
1181  state.addAttribute(
1182  spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
1183  builder.getI32IntegerAttr(binding));
1184 }
1185 
1186 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
1187  Type type, StringRef name,
1188  spirv::BuiltIn builtin) {
1189  build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
1190  state.addAttribute(
1191  spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
1192  builder.getStringAttr(spirv::stringifyBuiltIn(builtin)));
1193 }
1194 
1195 ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
1196  OperationState &result) {
1197  // Parse variable name.
1198  StringAttr nameAttr;
1199  StringRef initializerAttrName =
1200  spirv::GlobalVariableOp::getInitializerAttrName(result.name);
1201  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1202  result.attributes)) {
1203  return failure();
1204  }
1205 
1206  // Parse optional initializer
1207  if (succeeded(parser.parseOptionalKeyword(initializerAttrName))) {
1208  FlatSymbolRefAttr initSymbol;
1209  if (parser.parseLParen() ||
1210  parser.parseAttribute(initSymbol, Type(), initializerAttrName,
1211  result.attributes) ||
1212  parser.parseRParen())
1213  return failure();
1214  }
1215 
1216  if (parseVariableDecorations(parser, result)) {
1217  return failure();
1218  }
1219 
1220  Type type;
1221  StringRef typeAttrName =
1222  spirv::GlobalVariableOp::getTypeAttrName(result.name);
1223  auto loc = parser.getCurrentLocation();
1224  if (parser.parseColonType(type)) {
1225  return failure();
1226  }
1227  if (!llvm::isa<spirv::PointerType>(type)) {
1228  return parser.emitError(loc, "expected spirv.ptr type");
1229  }
1230  result.addAttribute(typeAttrName, TypeAttr::get(type));
1231 
1232  return success();
1233 }
1234 
1236  SmallVector<StringRef, 4> elidedAttrs{
1237  spirv::attributeName<spirv::StorageClass>()};
1238 
1239  // Print variable name.
1240  printer << ' ';
1241  printer.printSymbolName(getSymName());
1242  elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
1243 
1244  StringRef initializerAttrName = this->getInitializerAttrName();
1245  // Print optional initializer
1246  if (auto initializer = this->getInitializer()) {
1247  printer << " " << initializerAttrName << '(';
1248  printer.printSymbolName(*initializer);
1249  printer << ')';
1250  elidedAttrs.push_back(initializerAttrName);
1251  }
1252 
1253  StringRef typeAttrName = this->getTypeAttrName();
1254  elidedAttrs.push_back(typeAttrName);
1255  spirv::printVariableDecorations(*this, printer, elidedAttrs);
1256  printer << " : " << getType();
1257 }
1258 
1259 LogicalResult spirv::GlobalVariableOp::verify() {
1260  if (!llvm::isa<spirv::PointerType>(getType()))
1261  return emitOpError("result must be of a !spv.ptr type");
1262 
1263  // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
1264  // object. It cannot be Generic. It must be the same as the Storage Class
1265  // operand of the Result Type."
1266  // Also, Function storage class is reserved by spirv.Variable.
1267  auto storageClass = this->storageClass();
1268  if (storageClass == spirv::StorageClass::Generic ||
1269  storageClass == spirv::StorageClass::Function) {
1270  return emitOpError("storage class cannot be '")
1271  << stringifyStorageClass(storageClass) << "'";
1272  }
1273 
1274  if (auto init = (*this)->getAttrOfType<FlatSymbolRefAttr>(
1275  this->getInitializerAttrName())) {
1277  (*this)->getParentOp(), init.getAttr());
1278  // TODO: Currently only variable initialization with specialization
1279  // constants and other variables is supported. They could be normal
1280  // constants in the module scope as well.
1281  if (!initOp || !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp,
1282  spirv::SpecConstantCompositeOp>(initOp)) {
1283  return emitOpError("initializer must be result of a "
1284  "spirv.SpecConstant or spirv.GlobalVariable or "
1285  "spirv.SpecConstantCompositeOp op");
1286  }
1287  }
1288 
1289  return success();
1290 }
1291 
1292 //===----------------------------------------------------------------------===//
1293 // spirv.INTEL.SubgroupBlockRead
1294 //===----------------------------------------------------------------------===//
1295 
1297  if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
1298  return failure();
1299 
1300  return success();
1301 }
1302 
1303 //===----------------------------------------------------------------------===//
1304 // spirv.INTEL.SubgroupBlockWrite
1305 //===----------------------------------------------------------------------===//
1306 
1308  OperationState &result) {
1309  // Parse the storage class specification
1310  spirv::StorageClass storageClass;
1312  auto loc = parser.getCurrentLocation();
1313  Type elementType;
1314  if (parseEnumStrAttr(storageClass, parser) ||
1315  parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
1316  parser.parseType(elementType)) {
1317  return failure();
1318  }
1319 
1320  auto ptrType = spirv::PointerType::get(elementType, storageClass);
1321  if (auto valVecTy = llvm::dyn_cast<VectorType>(elementType))
1322  ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
1323 
1324  if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
1325  result.operands)) {
1326  return failure();
1327  }
1328  return success();
1329 }
1330 
1332  printer << " " << getPtr() << ", " << getValue() << " : "
1333  << getValue().getType();
1334 }
1335 
1337  if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
1338  return failure();
1339 
1340  return success();
1341 }
1342 
1343 //===----------------------------------------------------------------------===//
1344 // spirv.IAddCarryOp
1345 //===----------------------------------------------------------------------===//
1346 
1347 LogicalResult spirv::IAddCarryOp::verify() {
1349 }
1350 
1351 ParseResult spirv::IAddCarryOp::parse(OpAsmParser &parser,
1352  OperationState &result) {
1354 }
1355 
1356 void spirv::IAddCarryOp::print(OpAsmPrinter &printer) {
1357  ::printArithmeticExtendedBinaryOp(*this, printer);
1358 }
1359 
1360 //===----------------------------------------------------------------------===//
1361 // spirv.ISubBorrowOp
1362 //===----------------------------------------------------------------------===//
1363 
1364 LogicalResult spirv::ISubBorrowOp::verify() {
1366 }
1367 
1368 ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser,
1369  OperationState &result) {
1371 }
1372 
1374  ::printArithmeticExtendedBinaryOp(*this, printer);
1375 }
1376 
1377 //===----------------------------------------------------------------------===//
1378 // spirv.SMulExtended
1379 //===----------------------------------------------------------------------===//
1380 
1381 LogicalResult spirv::SMulExtendedOp::verify() {
1383 }
1384 
1385 ParseResult spirv::SMulExtendedOp::parse(OpAsmParser &parser,
1386  OperationState &result) {
1388 }
1389 
1391  ::printArithmeticExtendedBinaryOp(*this, printer);
1392 }
1393 
1394 //===----------------------------------------------------------------------===//
1395 // spirv.UMulExtended
1396 //===----------------------------------------------------------------------===//
1397 
1398 LogicalResult spirv::UMulExtendedOp::verify() {
1400 }
1401 
1402 ParseResult spirv::UMulExtendedOp::parse(OpAsmParser &parser,
1403  OperationState &result) {
1405 }
1406 
1408  ::printArithmeticExtendedBinaryOp(*this, printer);
1409 }
1410 
1411 //===----------------------------------------------------------------------===//
1412 // spirv.MemoryBarrierOp
1413 //===----------------------------------------------------------------------===//
1414 
1415 LogicalResult spirv::MemoryBarrierOp::verify() {
1416  return verifyMemorySemantics(getOperation(), getMemorySemantics());
1417 }
1418 
1419 //===----------------------------------------------------------------------===//
1420 // spirv.module
1421 //===----------------------------------------------------------------------===//
1422 
1423 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
1424  std::optional<StringRef> name) {
1425  OpBuilder::InsertionGuard guard(builder);
1426  builder.createBlock(state.addRegion());
1427  if (name) {
1428  state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
1429  builder.getStringAttr(*name));
1430  }
1431 }
1432 
1433 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
1434  spirv::AddressingModel addressingModel,
1435  spirv::MemoryModel memoryModel,
1436  std::optional<VerCapExtAttr> vceTriple,
1437  std::optional<StringRef> name) {
1438  state.addAttribute(
1439  "addressing_model",
1440  builder.getAttr<spirv::AddressingModelAttr>(addressingModel));
1441  state.addAttribute("memory_model",
1442  builder.getAttr<spirv::MemoryModelAttr>(memoryModel));
1443  OpBuilder::InsertionGuard guard(builder);
1444  builder.createBlock(state.addRegion());
1445  if (vceTriple)
1446  state.addAttribute(getVCETripleAttrName(), *vceTriple);
1447  if (name)
1448  state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
1449  builder.getStringAttr(*name));
1450 }
1451 
1452 ParseResult spirv::ModuleOp::parse(OpAsmParser &parser,
1453  OperationState &result) {
1454  Region *body = result.addRegion();
1455 
1456  // If the name is present, parse it.
1457  StringAttr nameAttr;
1458  (void)parser.parseOptionalSymbolName(
1459  nameAttr, mlir::SymbolTable::getSymbolAttrName(), result.attributes);
1460 
1461  // Parse attributes
1462  spirv::AddressingModel addrModel;
1463  spirv::MemoryModel memoryModel;
1464  if (spirv::parseEnumKeywordAttr<spirv::AddressingModelAttr>(addrModel, parser,
1465  result) ||
1466  spirv::parseEnumKeywordAttr<spirv::MemoryModelAttr>(memoryModel, parser,
1467  result))
1468  return failure();
1469 
1470  if (succeeded(parser.parseOptionalKeyword("requires"))) {
1471  spirv::VerCapExtAttr vceTriple;
1472  if (parser.parseAttribute(vceTriple,
1473  spirv::ModuleOp::getVCETripleAttrName(),
1474  result.attributes))
1475  return failure();
1476  }
1477 
1478  if (parser.parseOptionalAttrDictWithKeyword(result.attributes) ||
1479  parser.parseRegion(*body, /*arguments=*/{}))
1480  return failure();
1481 
1482  // Make sure we have at least one block.
1483  if (body->empty())
1484  body->push_back(new Block());
1485 
1486  return success();
1487 }
1488 
1489 void spirv::ModuleOp::print(OpAsmPrinter &printer) {
1490  if (std::optional<StringRef> name = getName()) {
1491  printer << ' ';
1492  printer.printSymbolName(*name);
1493  }
1494 
1495  SmallVector<StringRef, 2> elidedAttrs;
1496 
1497  printer << " " << spirv::stringifyAddressingModel(getAddressingModel()) << " "
1498  << spirv::stringifyMemoryModel(getMemoryModel());
1499  auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
1500  auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
1501  elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
1503 
1504  if (std::optional<spirv::VerCapExtAttr> triple = getVceTriple()) {
1505  printer << " requires " << *triple;
1506  elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
1507  }
1508 
1509  printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs);
1510  printer << ' ';
1511  printer.printRegion(getRegion());
1512 }
1513 
1514 LogicalResult spirv::ModuleOp::verifyRegions() {
1515  Dialect *dialect = (*this)->getDialect();
1517  entryPoints;
1518  mlir::SymbolTable table(*this);
1519 
1520  for (auto &op : *getBody()) {
1521  if (op.getDialect() != dialect)
1522  return op.emitError("'spirv.module' can only contain spirv.* ops");
1523 
1524  // For EntryPoint op, check that the function and execution model is not
1525  // duplicated in EntryPointOps. Also verify that the interface specified
1526  // comes from globalVariables here to make this check cheaper.
1527  if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
1528  auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.getFn());
1529  if (!funcOp) {
1530  return entryPointOp.emitError("function '")
1531  << entryPointOp.getFn() << "' not found in 'spirv.module'";
1532  }
1533  if (auto interface = entryPointOp.getInterface()) {
1534  for (Attribute varRef : interface) {
1535  auto varSymRef = llvm::dyn_cast<FlatSymbolRefAttr>(varRef);
1536  if (!varSymRef) {
1537  return entryPointOp.emitError(
1538  "expected symbol reference for interface "
1539  "specification instead of '")
1540  << varRef;
1541  }
1542  auto variableOp =
1543  table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
1544  if (!variableOp) {
1545  return entryPointOp.emitError("expected spirv.GlobalVariable "
1546  "symbol reference instead of'")
1547  << varSymRef << "'";
1548  }
1549  }
1550  }
1551 
1552  auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
1553  funcOp, entryPointOp.getExecutionModel());
1554  if (!entryPoints.try_emplace(key, entryPointOp).second)
1555  return entryPointOp.emitError("duplicate of a previous EntryPointOp");
1556  } else if (auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
1557  // If the function is external and does not have 'Import'
1558  // linkage_attributes(LinkageAttributes), throw an error. 'Import'
1559  // LinkageAttributes is used to import external functions.
1560  auto linkageAttr = funcOp.getLinkageAttributes();
1561  auto hasImportLinkage =
1562  linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
1563  spirv::LinkageType::Import);
1564  if (funcOp.isExternal() && !hasImportLinkage)
1565  return op.emitError(
1566  "'spirv.module' cannot contain external functions "
1567  "without 'Import' linkage_attributes (LinkageAttributes)");
1568 
1569  // TODO: move this check to spirv.func.
1570  for (auto &block : funcOp)
1571  for (auto &op : block) {
1572  if (op.getDialect() != dialect)
1573  return op.emitError(
1574  "functions in 'spirv.module' can only contain spirv.* ops");
1575  }
1576  }
1577  }
1578 
1579  return success();
1580 }
1581 
1582 //===----------------------------------------------------------------------===//
1583 // spirv.mlir.referenceof
1584 //===----------------------------------------------------------------------===//
1585 
1586 LogicalResult spirv::ReferenceOfOp::verify() {
1587  auto *specConstSym = SymbolTable::lookupNearestSymbolFrom(
1588  (*this)->getParentOp(), getSpecConstAttr());
1589  Type constType;
1590 
1591  auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
1592  if (specConstOp)
1593  constType = specConstOp.getDefaultValue().getType();
1594 
1595  auto specConstCompositeOp =
1596  dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
1597  if (specConstCompositeOp)
1598  constType = specConstCompositeOp.getType();
1599 
1600  if (!specConstOp && !specConstCompositeOp)
1601  return emitOpError(
1602  "expected spirv.SpecConstant or spirv.SpecConstantComposite symbol");
1603 
1604  if (getReference().getType() != constType)
1605  return emitOpError("result type mismatch with the referenced "
1606  "specialization constant's type");
1607 
1608  return success();
1609 }
1610 
1611 //===----------------------------------------------------------------------===//
1612 // spirv.SpecConstant
1613 //===----------------------------------------------------------------------===//
1614 
1615 ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
1616  OperationState &result) {
1617  StringAttr nameAttr;
1618  Attribute valueAttr;
1619  StringRef defaultValueAttrName =
1620  spirv::SpecConstantOp::getDefaultValueAttrName(result.name);
1621 
1622  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1623  result.attributes))
1624  return failure();
1625 
1626  // Parse optional spec_id.
1627  if (succeeded(parser.parseOptionalKeyword(kSpecIdAttrName))) {
1628  IntegerAttr specIdAttr;
1629  if (parser.parseLParen() ||
1630  parser.parseAttribute(specIdAttr, kSpecIdAttrName, result.attributes) ||
1631  parser.parseRParen())
1632  return failure();
1633  }
1634 
1635  if (parser.parseEqual() ||
1636  parser.parseAttribute(valueAttr, defaultValueAttrName, result.attributes))
1637  return failure();
1638 
1639  return success();
1640 }
1641 
1643  printer << ' ';
1644  printer.printSymbolName(getSymName());
1645  if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
1646  printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')';
1647  printer << " = " << getDefaultValue();
1648 }
1649 
1650 LogicalResult spirv::SpecConstantOp::verify() {
1651  if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
1652  if (specID.getValue().isNegative())
1653  return emitOpError("SpecId cannot be negative");
1654 
1655  auto value = getDefaultValue();
1656  if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
1657  // Make sure bitwidth is allowed.
1658  if (!llvm::isa<spirv::SPIRVType>(value.getType()))
1659  return emitOpError("default value bitwidth disallowed");
1660  return success();
1661  }
1662  return emitOpError(
1663  "default value can only be a bool, integer, or float scalar");
1664 }
1665 
1666 //===----------------------------------------------------------------------===//
1667 // spirv.VectorShuffle
1668 //===----------------------------------------------------------------------===//
1669 
1670 LogicalResult spirv::VectorShuffleOp::verify() {
1671  VectorType resultType = llvm::cast<VectorType>(getType());
1672 
1673  size_t numResultElements = resultType.getNumElements();
1674  if (numResultElements != getComponents().size())
1675  return emitOpError("result type element count (")
1676  << numResultElements
1677  << ") mismatch with the number of component selectors ("
1678  << getComponents().size() << ")";
1679 
1680  size_t totalSrcElements =
1681  llvm::cast<VectorType>(getVector1().getType()).getNumElements() +
1682  llvm::cast<VectorType>(getVector2().getType()).getNumElements();
1683 
1684  for (const auto &selector : getComponents().getAsValueRange<IntegerAttr>()) {
1685  uint32_t index = selector.getZExtValue();
1686  if (index >= totalSrcElements &&
1687  index != std::numeric_limits<uint32_t>().max())
1688  return emitOpError("component selector ")
1689  << index << " out of range: expected to be in [0, "
1690  << totalSrcElements << ") or 0xffffffff";
1691  }
1692  return success();
1693 }
1694 
1695 //===----------------------------------------------------------------------===//
1696 // spirv.MatrixTimesScalar
1697 //===----------------------------------------------------------------------===//
1698 
1699 LogicalResult spirv::MatrixTimesScalarOp::verify() {
1700  Type elementType =
1701  llvm::TypeSwitch<Type, Type>(getMatrix().getType())
1703  [](auto matrixType) { return matrixType.getElementType(); })
1704  .Default([](Type) { return nullptr; });
1705 
1706  assert(elementType && "Unhandled type");
1707 
1708  // Check that the scalar type is the same as the matrix element type.
1709  if (getScalar().getType() != elementType)
1710  return emitOpError("input matrix components' type and scaling value must "
1711  "have the same type");
1712 
1713  return success();
1714 }
1715 
1716 //===----------------------------------------------------------------------===//
1717 // spirv.Transpose
1718 //===----------------------------------------------------------------------===//
1719 
1720 LogicalResult spirv::TransposeOp::verify() {
1721  auto inputMatrix = llvm::cast<spirv::MatrixType>(getMatrix().getType());
1722  auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().getType());
1723 
1724  // Verify that the input and output matrices have correct shapes.
1725  if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
1726  return emitError("input matrix rows count must be equal to "
1727  "output matrix columns count");
1728 
1729  if (inputMatrix.getNumColumns() != resultMatrix.getNumRows())
1730  return emitError("input matrix columns count must be equal to "
1731  "output matrix rows count");
1732 
1733  // Verify that the input and output matrices have the same component type
1734  if (inputMatrix.getElementType() != resultMatrix.getElementType())
1735  return emitError("input and output matrices must have the same "
1736  "component type");
1737 
1738  return success();
1739 }
1740 
1741 //===----------------------------------------------------------------------===//
1742 // spirv.MatrixTimesVector
1743 //===----------------------------------------------------------------------===//
1744 
1745 LogicalResult spirv::MatrixTimesVectorOp::verify() {
1746  auto matrixType = llvm::cast<spirv::MatrixType>(getMatrix().getType());
1747  auto vectorType = llvm::cast<VectorType>(getVector().getType());
1748  auto resultType = llvm::cast<VectorType>(getType());
1749 
1750  if (matrixType.getNumColumns() != vectorType.getNumElements())
1751  return emitOpError("matrix columns (")
1752  << matrixType.getNumColumns() << ") must match vector operand size ("
1753  << vectorType.getNumElements() << ")";
1754 
1755  if (resultType.getNumElements() != matrixType.getNumRows())
1756  return emitOpError("result size (")
1757  << resultType.getNumElements() << ") must match the matrix rows ("
1758  << matrixType.getNumRows() << ")";
1759 
1760  if (matrixType.getElementType() != resultType.getElementType())
1761  return emitOpError("matrix and result element types must match");
1762 
1763  return success();
1764 }
1765 
1766 //===----------------------------------------------------------------------===//
1767 // spirv.VectorTimesMatrix
1768 //===----------------------------------------------------------------------===//
1769 
1770 LogicalResult spirv::VectorTimesMatrixOp::verify() {
1771  auto vectorType = llvm::cast<VectorType>(getVector().getType());
1772  auto matrixType = llvm::cast<spirv::MatrixType>(getMatrix().getType());
1773  auto resultType = llvm::cast<VectorType>(getType());
1774 
1775  if (matrixType.getNumRows() != vectorType.getNumElements())
1776  return emitOpError("number of components in vector must equal the number "
1777  "of components in each column in matrix");
1778 
1779  if (resultType.getNumElements() != matrixType.getNumColumns())
1780  return emitOpError("number of columns in matrix must equal the number of "
1781  "components in result");
1782 
1783  if (matrixType.getElementType() != resultType.getElementType())
1784  return emitOpError("matrix must be a matrix with the same component type "
1785  "as the component type in result");
1786 
1787  return success();
1788 }
1789 
1790 //===----------------------------------------------------------------------===//
1791 // spirv.MatrixTimesMatrix
1792 //===----------------------------------------------------------------------===//
1793 
1794 LogicalResult spirv::MatrixTimesMatrixOp::verify() {
1795  auto leftMatrix = llvm::cast<spirv::MatrixType>(getLeftmatrix().getType());
1796  auto rightMatrix = llvm::cast<spirv::MatrixType>(getRightmatrix().getType());
1797  auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().getType());
1798 
1799  // left matrix columns' count and right matrix rows' count must be equal
1800  if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
1801  return emitError("left matrix columns' count must be equal to "
1802  "the right matrix rows' count");
1803 
1804  // right and result matrices columns' count must be the same
1805  if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns())
1806  return emitError(
1807  "right and result matrices must have equal columns' count");
1808 
1809  // right and result matrices component type must be the same
1810  if (rightMatrix.getElementType() != resultMatrix.getElementType())
1811  return emitError("right and result matrices' component type must"
1812  " be the same");
1813 
1814  // left and result matrices component type must be the same
1815  if (leftMatrix.getElementType() != resultMatrix.getElementType())
1816  return emitError("left and result matrices' component type"
1817  " must be the same");
1818 
1819  // left and result matrices rows count must be the same
1820  if (leftMatrix.getNumRows() != resultMatrix.getNumRows())
1821  return emitError("left and result matrices must have equal rows' count");
1822 
1823  return success();
1824 }
1825 
1826 //===----------------------------------------------------------------------===//
1827 // spirv.SpecConstantComposite
1828 //===----------------------------------------------------------------------===//
1829 
1831  OperationState &result) {
1832 
1833  StringAttr compositeName;
1834  if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
1835  result.attributes))
1836  return failure();
1837 
1838  if (parser.parseLParen())
1839  return failure();
1840 
1841  SmallVector<Attribute, 4> constituents;
1842 
1843  do {
1844  // The name of the constituent attribute isn't important
1845  const char *attrName = "spec_const";
1846  FlatSymbolRefAttr specConstRef;
1847  NamedAttrList attrs;
1848 
1849  if (parser.parseAttribute(specConstRef, Type(), attrName, attrs))
1850  return failure();
1851 
1852  constituents.push_back(specConstRef);
1853  } while (!parser.parseOptionalComma());
1854 
1855  if (parser.parseRParen())
1856  return failure();
1857 
1858  StringAttr compositeSpecConstituentsName =
1859  spirv::SpecConstantCompositeOp::getConstituentsAttrName(result.name);
1860  result.addAttribute(compositeSpecConstituentsName,
1861  parser.getBuilder().getArrayAttr(constituents));
1862 
1863  Type type;
1864  if (parser.parseColonType(type))
1865  return failure();
1866 
1867  StringAttr typeAttrName =
1868  spirv::SpecConstantCompositeOp::getTypeAttrName(result.name);
1869  result.addAttribute(typeAttrName, TypeAttr::get(type));
1870 
1871  return success();
1872 }
1873 
1875  printer << " ";
1876  printer.printSymbolName(getSymName());
1877  printer << " (" << llvm::interleaved(this->getConstituents().getValue())
1878  << ") : " << getType();
1879 }
1880 
1882  auto cType = llvm::dyn_cast<spirv::CompositeType>(getType());
1883  auto constituents = this->getConstituents().getValue();
1884 
1885  if (!cType)
1886  return emitError("result type must be a composite type, but provided ")
1887  << getType();
1888 
1889  if (llvm::isa<spirv::CooperativeMatrixType>(cType))
1890  return emitError("unsupported composite type ") << cType;
1891  if (constituents.size() != cType.getNumElements())
1892  return emitError("has incorrect number of operands: expected ")
1893  << cType.getNumElements() << ", but provided "
1894  << constituents.size();
1895 
1896  for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
1897  auto constituent = llvm::cast<FlatSymbolRefAttr>(constituents[index]);
1898 
1899  auto constituentSpecConstOp =
1900  dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom(
1901  (*this)->getParentOp(), constituent.getAttr()));
1902 
1903  if (constituentSpecConstOp.getDefaultValue().getType() !=
1904  cType.getElementType(index))
1905  return emitError("has incorrect types of operands: expected ")
1906  << cType.getElementType(index) << ", but provided "
1907  << constituentSpecConstOp.getDefaultValue().getType();
1908  }
1909 
1910  return success();
1911 }
1912 
1913 //===----------------------------------------------------------------------===//
1914 // spirv.EXTSpecConstantCompositeReplicateOp
1915 //===----------------------------------------------------------------------===//
1916 
1917 ParseResult
1919  OperationState &result) {
1920  StringAttr compositeName;
1921  FlatSymbolRefAttr specConstRef;
1922  const char *attrName = "spec_const";
1923  NamedAttrList attrs;
1924  Type type;
1925 
1926  if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
1927  result.attributes) ||
1928  parser.parseLParen() ||
1929  parser.parseAttribute(specConstRef, Type(), attrName, attrs) ||
1930  parser.parseRParen() || parser.parseColonType(type))
1931  return failure();
1932 
1933  StringAttr compositeSpecConstituentName =
1934  spirv::EXTSpecConstantCompositeReplicateOp::getConstituentAttrName(
1935  result.name);
1936  result.addAttribute(compositeSpecConstituentName, specConstRef);
1937 
1938  StringAttr typeAttrName =
1939  spirv::EXTSpecConstantCompositeReplicateOp::getTypeAttrName(result.name);
1940  result.addAttribute(typeAttrName, TypeAttr::get(type));
1941 
1942  return success();
1943 }
1944 
1946  printer << " ";
1947  printer.printSymbolName(getSymName());
1948  printer << " (" << this->getConstituent() << ") : " << getType();
1949 }
1950 
1952  auto compositeType = dyn_cast<spirv::CompositeType>(getType());
1953  if (!compositeType)
1954  return emitError("result type must be a composite type, but provided ")
1955  << getType();
1956 
1958  (*this)->getParentOp(), this->getConstituent());
1959  if (!constituentOp)
1960  return emitError(
1961  "splat spec constant reference defining constituent not found");
1962 
1963  auto constituentSpecConstOp = dyn_cast<spirv::SpecConstantOp>(constituentOp);
1964  if (!constituentSpecConstOp)
1965  return emitError("constituent is not a spec constant");
1966 
1967  Type constituentType = constituentSpecConstOp.getDefaultValue().getType();
1968  Type compositeElementType = compositeType.getElementType(0);
1969  if (constituentType != compositeElementType)
1970  return emitError("constituent has incorrect type: expected ")
1971  << compositeElementType << ", but provided " << constituentType;
1972 
1973  return success();
1974 }
1975 
1976 //===----------------------------------------------------------------------===//
1977 // spirv.SpecConstantOperation
1978 //===----------------------------------------------------------------------===//
1979 
1981  OperationState &result) {
1982  Region *body = result.addRegion();
1983 
1984  if (parser.parseKeyword("wraps"))
1985  return failure();
1986 
1987  body->push_back(new Block);
1988  Block &block = body->back();
1989  Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
1990 
1991  if (!wrappedOp)
1992  return failure();
1993 
1994  OpBuilder builder(parser.getContext());
1995  builder.setInsertionPointToEnd(&block);
1996  spirv::YieldOp::create(builder, wrappedOp->getLoc(), wrappedOp->getResult(0));
1997  result.location = wrappedOp->getLoc();
1998 
1999  result.addTypes(wrappedOp->getResult(0).getType());
2000 
2001  if (parser.parseOptionalAttrDict(result.attributes))
2002  return failure();
2003 
2004  return success();
2005 }
2006 
2008  printer << " wraps ";
2009  printer.printGenericOp(&getBody().front().front());
2010 }
2011 
2012 LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
2013  Block &block = getRegion().getBlocks().front();
2014 
2015  if (block.getOperations().size() != 2)
2016  return emitOpError("expected exactly 2 nested ops");
2017 
2018  Operation &enclosedOp = block.getOperations().front();
2019 
2021  return emitOpError("invalid enclosed op");
2022 
2023  for (auto operand : enclosedOp.getOperands())
2024  if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
2025  spirv::SpecConstantOperationOp>(operand.getDefiningOp()))
2026  return emitOpError(
2027  "invalid operand, must be defined by a constant operation");
2028 
2029  return success();
2030 }
2031 
2032 //===----------------------------------------------------------------------===//
2033 // spirv.GL.FrexpStruct
2034 //===----------------------------------------------------------------------===//
2035 
2036 LogicalResult spirv::GLFrexpStructOp::verify() {
2037  spirv::StructType structTy =
2038  llvm::dyn_cast<spirv::StructType>(getResult().getType());
2039 
2040  if (structTy.getNumElements() != 2)
2041  return emitError("result type must be a struct type with two memebers");
2042 
2043  Type significandTy = structTy.getElementType(0);
2044  Type exponentTy = structTy.getElementType(1);
2045  VectorType exponentVecTy = llvm::dyn_cast<VectorType>(exponentTy);
2046  IntegerType exponentIntTy = llvm::dyn_cast<IntegerType>(exponentTy);
2047 
2048  Type operandTy = getOperand().getType();
2049  VectorType operandVecTy = llvm::dyn_cast<VectorType>(operandTy);
2050  FloatType operandFTy = llvm::dyn_cast<FloatType>(operandTy);
2051 
2052  if (significandTy != operandTy)
2053  return emitError("member zero of the resulting struct type must be the "
2054  "same type as the operand");
2055 
2056  if (exponentVecTy) {
2057  IntegerType componentIntTy =
2058  llvm::dyn_cast<IntegerType>(exponentVecTy.getElementType());
2059  if (!componentIntTy || componentIntTy.getWidth() != 32)
2060  return emitError("member one of the resulting struct type must"
2061  "be a scalar or vector of 32 bit integer type");
2062  } else if (!exponentIntTy || exponentIntTy.getWidth() != 32) {
2063  return emitError("member one of the resulting struct type "
2064  "must be a scalar or vector of 32 bit integer type");
2065  }
2066 
2067  // Check that the two member types have the same number of components
2068  if (operandVecTy && exponentVecTy &&
2069  (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
2070  return success();
2071 
2072  if (operandFTy && exponentIntTy)
2073  return success();
2074 
2075  return emitError("member one of the resulting struct type must have the same "
2076  "number of components as the operand type");
2077 }
2078 
2079 //===----------------------------------------------------------------------===//
2080 // spirv.GL.Ldexp
2081 //===----------------------------------------------------------------------===//
2082 
2083 LogicalResult spirv::GLLdexpOp::verify() {
2084  Type significandType = getX().getType();
2085  Type exponentType = getExp().getType();
2086 
2087  if (llvm::isa<FloatType>(significandType) !=
2088  llvm::isa<IntegerType>(exponentType))
2089  return emitOpError("operands must both be scalars or vectors");
2090 
2091  auto getNumElements = [](Type type) -> unsigned {
2092  if (auto vectorType = llvm::dyn_cast<VectorType>(type))
2093  return vectorType.getNumElements();
2094  return 1;
2095  };
2096 
2097  if (getNumElements(significandType) != getNumElements(exponentType))
2098  return emitOpError("operands must have the same number of elements");
2099 
2100  return success();
2101 }
2102 
2103 //===----------------------------------------------------------------------===//
2104 // spirv.ShiftLeftLogicalOp
2105 //===----------------------------------------------------------------------===//
2106 
2107 LogicalResult spirv::ShiftLeftLogicalOp::verify() {
2108  return verifyShiftOp(*this);
2109 }
2110 
2111 //===----------------------------------------------------------------------===//
2112 // spirv.ShiftRightArithmeticOp
2113 //===----------------------------------------------------------------------===//
2114 
2115 LogicalResult spirv::ShiftRightArithmeticOp::verify() {
2116  return verifyShiftOp(*this);
2117 }
2118 
2119 //===----------------------------------------------------------------------===//
2120 // spirv.ShiftRightLogicalOp
2121 //===----------------------------------------------------------------------===//
2122 
2123 LogicalResult spirv::ShiftRightLogicalOp::verify() {
2124  return verifyShiftOp(*this);
2125 }
2126 
2127 //===----------------------------------------------------------------------===//
2128 // spirv.VectorTimesScalarOp
2129 //===----------------------------------------------------------------------===//
2130 
2131 LogicalResult spirv::VectorTimesScalarOp::verify() {
2132  if (getVector().getType() != getType())
2133  return emitOpError("vector operand and result type mismatch");
2134  auto scalarType = llvm::cast<VectorType>(getType()).getElementType();
2135  if (getScalar().getType() != scalarType)
2136  return emitOpError("scalar operand and result element type match");
2137  return success();
2138 }
static std::string bindingName()
Returns the string name of the Binding decoration.
static std::string descriptorSetName()
Returns the string name of the DescriptorSet decoration.
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser, OperationState &result)
Definition: SPIRVOps.cpp:267
static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, Type opType)
Definition: SPIRVOps.cpp:563
static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, OperationState &result)
Definition: SPIRVOps.cpp:120
static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op)
Definition: SPIRVOps.cpp:253
static LogicalResult verifyShiftOp(Operation *op)
Definition: SPIRVOps.cpp:299
static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op, Value ptr, Value val)
Definition: SPIRVOps.cpp:169
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:186
static void printOneResultOp(Operation *op, OpAsmPrinter &p)
Definition: SPIRVOps.cpp:149
static void printArithmeticExtendedBinaryOp(Operation *op, OpAsmPrinter &printer)
Definition: SPIRVOps.cpp:291
const float * table
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseOptionalSymbolName(StringAttr &result)=0
Parse an optional -identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
ParseResult addTypesToList(ArrayRef< Type > types, SmallVectorImpl< Type > &result)
Add the specified types to the end of the specified type list and return success.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:149
unsigned getNumArguments()
Definition: Block.h:128
OpListType & getOperations()
Definition: Block.h:137
Operation & front()
Definition: Block.h:153
iterator begin()
Definition: Block.h:143
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:199
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:227
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:275
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:253
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:75
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:66
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:99
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:261
MLIRContext * getContext() const
Definition: Builders.h:56
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:265
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition: Builders.h:98
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual Operation * parseGenericOperation(Block *insertBlock, Block::iterator insertPt)=0
Parse an operation in its generic form.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printGenericOp(Operation *op, bool printOpName=true)=0
Print the entire operation with the default generic assembly form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:348
This class helps build Operations.
Definition: Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:429
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:436
A trait to mark ops that can be enclosed/wrapped in a SpecConstantOperation op.
Definition: SPIRVOpTraits.h:33
type_range getType() const
Definition: ValueRange.cpp:32
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:550
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:40
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:50
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
void push_back(Block *block)
Definition: Region.h:61
bool empty()
Definition: Region.h:60
Block & back()
Definition: Region.h:64
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:107
Type front()
Return first type in the range.
Definition: TypeRange.h:152
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: WalkResult.h:29
static WalkResult advance()
Definition: WalkResult.h:47
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:158
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:454
SPIR-V struct type.
Definition: SPIRVTypes.h:251
unsigned getNumElements() const
Type getElementType(unsigned) const
An attribute that specifies the SPIR-V (version, capabilities, extensions) triple.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:102
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the argument at 'index'.
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
constexpr char kFnNameAttrName[]
constexpr char kSpecIdAttrName[]
static Type getValueType(Attribute attr)
LogicalResult verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics)
Definition: SPIRVOps.cpp:69
ParseResult parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next string attribute in parser as an enumerant of the given EnumClass.
void printVariableDecorations(Operation *op, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs)
Definition: SPIRVOps.cpp:93
AddressingModel getAddressingModel(TargetEnvAttr targetAttr, bool use64bitAddress)
Returns addressing model selected based on target environment.
FailureOr< ExecutionModel > getExecutionModel(TargetEnvAttr targetAttr)
Returns execution model selected based on target environment.
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
LogicalResult extractValueFromConstOp(Operation *op, int32_t &value)
Definition: SPIRVOps.cpp:49
ParseResult parseVariableDecorations(OpAsmParser &parser, OperationState &state)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.