MLIR  22.0.0git
SPIRVOps.cpp
Go to the documentation of this file.
1 //===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the operations in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
15 #include "SPIRVOpUtils.h"
16 #include "SPIRVParsingUtils.h"
17 
24 #include "mlir/IR/Builders.h"
25 #include "mlir/IR/BuiltinTypes.h"
26 #include "mlir/IR/OpDefinition.h"
28 #include "mlir/IR/Operation.h"
29 #include "mlir/IR/TypeUtilities.h"
31 #include "llvm/ADT/APFloat.h"
32 #include "llvm/ADT/APInt.h"
33 #include "llvm/ADT/ArrayRef.h"
34 #include "llvm/ADT/STLExtras.h"
35 #include "llvm/ADT/StringExtras.h"
36 #include "llvm/ADT/TypeSwitch.h"
37 #include "llvm/Support/InterleavedRange.h"
38 #include <cassert>
39 #include <numeric>
40 #include <optional>
41 
42 using namespace mlir;
43 using namespace mlir::spirv::AttrNames;
44 
45 //===----------------------------------------------------------------------===//
46 // Common utility functions
47 //===----------------------------------------------------------------------===//
48 
49 LogicalResult spirv::extractValueFromConstOp(Operation *op, int32_t &value) {
50  auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
51  if (!constOp) {
52  return failure();
53  }
54  auto valueAttr = constOp.getValue();
55  auto integerValueAttr = llvm::dyn_cast<IntegerAttr>(valueAttr);
56  if (!integerValueAttr) {
57  return failure();
58  }
59 
60  if (integerValueAttr.getType().isSignlessInteger())
61  value = integerValueAttr.getInt();
62  else
63  value = integerValueAttr.getSInt();
64 
65  return success();
66 }
67 
// Checks the memory-semantics mask reported for `op`: at most one of the four
// ordering bits (Acquire, Release, AcquireRelease, SequentiallyConsistent) may
// be set. Other bits of the mask are not inspected here. Emits an error on
// `op` and returns failure when two or more ordering bits are present.
68 LogicalResult
70  spirv::MemorySemantics memorySemantics) {
71  // According to the SPIR-V specification:
72  // "Despite being a mask and allowing multiple bits to be combined, it is
73  // invalid for more than one of these four bits to be set: Acquire, Release,
74  // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and
75  // Release semantics is done by setting the AcquireRelease bit, not by setting
76  // two bits."
77  auto atMostOneInSet = spirv::MemorySemantics::Acquire |
78  spirv::MemorySemantics::Release |
79  spirv::MemorySemantics::AcquireRelease |
80  spirv::MemorySemantics::SequentiallyConsistent;
81 
// Keep only the four mutually exclusive ordering bits and count them.
82  auto bitCount =
83  llvm::popcount(static_cast<uint32_t>(memorySemantics & atMostOneInSet));
84  if (bitCount > 1) {
// NOTE(review): the adjacent string literals below concatenate to
// "...`Release`,`AcquireRelease`..." with no space after the comma.
85  return op->emitError(
86  "expected at most one of these four memory constraints "
87  "to be set: `Acquire`, `Release`,"
88  "`AcquireRelease` or `SequentiallyConsistent`");
89  }
90  return success();
91 }
92 
94  SmallVectorImpl<StringRef> &elidedAttrs) {
// Prints a variable op's decorations in custom form, then the remaining
// attribute dictionary. Attributes printed in custom form are appended to
// `elidedAttrs` so printOptionalAttrDict does not print them a second time.
95  // Print `bind(<descriptor_set>, <binding>)`, but only when both attributes
// are present; a lone descriptor_set or binding is left for the attribute
// dictionary below.
96  auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
97  stringifyDecoration(spirv::Decoration::DescriptorSet));
98  auto bindingName = llvm::convertToSnakeFromCamelCase(
99  stringifyDecoration(spirv::Decoration::Binding));
100  auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
101  auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
102  if (descriptorSet && binding) {
// The names pushed here refer to the local std::string buffers above, which
// stay alive until printOptionalAttrDict at the end of this function.
103  elidedAttrs.push_back(descriptorSetName);
104  elidedAttrs.push_back(bindingName);
105  printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
106  << ")";
107  }
108 
109  // Print the BuiltIn decoration as `built_in("<name>")` if present.
110  auto builtInName = llvm::convertToSnakeFromCamelCase(
111  stringifyDecoration(spirv::Decoration::BuiltIn));
112  if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
113  printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
114  elidedAttrs.push_back(builtInName);
115  }
116 
117  printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
118 }
119 
// Parses a single-result op whose operands share the result's type. Two forms
// are accepted:
//   (operands) {attrs} : function-type   -- generic fallback; operand types
//                                           come from the function inputs and
//                                           result types from its results
//   operands {attrs} : type              -- compact form; `type` is used for
//                                           every operand and the one result
120 static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
121  OperationState &result) {
123  Type type;
124  // If the operand list is in-between parentheses, then we have a generic form.
125  // (see the fallback in `printOneResultOp`).
126  SMLoc loc = parser.getCurrentLocation();
127  if (!parser.parseOptionalLParen()) {
128  if (parser.parseOperandList(ops) || parser.parseRParen() ||
129  parser.parseOptionalAttrDict(result.attributes) ||
130  parser.parseColon() || parser.parseType(type))
131  return failure();
132  auto fnType = llvm::dyn_cast<FunctionType>(type);
133  if (!fnType) {
134  parser.emitError(loc, "expected function type");
135  return failure();
136  }
137  if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands))
138  return failure();
139  result.addTypes(fnType.getResults());
140  return success();
141  }
// Compact form: one type shared by all operands and the result.
142  return failure(parser.parseOperandList(ops) ||
143  parser.parseOptionalAttrDict(result.attributes) ||
144  parser.parseColonType(type) ||
145  parser.resolveOperands(ops, type, result.operands) ||
146  parser.addTypeToList(type, result.types));
147 }
148 
// Prints an op with exactly one result in the compact form
// `operands : type`. When any operand type differs from the result type, a
// single trailing type would lose information, so the generic form (without
// the op name) is printed instead.
150  assert(op->getNumResults() == 1 && "op should have one result");
151 
152  // If not all the operand and result types are the same, just use the
153  // generic assembly form to avoid omitting information in printing.
154  auto resultType = op->getResult(0).getType();
155  if (llvm::any_of(op->getOperandTypes(),
156  [&](Type type) { return type != resultType; })) {
157  p.printGenericOp(op, /*printOpName=*/false);
158  return;
159  }
160 
161  p << ' ';
162  p.printOperands(op->getOperands());
164  // Now we can output only one type for all operands and the result.
165  p << " : " << resultType;
166 }
167 
168 template <typename BlockReadWriteOpTy>
169 static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
170  Value ptr, Value val) {
171  auto valType = val.getType();
172  if (auto valVecTy = llvm::dyn_cast<VectorType>(valType))
173  valType = valVecTy.getElementType();
174 
175  if (valType !=
176  llvm::cast<spirv::PointerType>(ptr.getType()).getPointeeType()) {
177  return op.emitOpError("mismatch in result type and pointer type");
178  }
179  return success();
180 }
181 
182 /// Walks the given type hierarchy with the given indices, potentially down
183 /// to component granularity, to select an element type. Returns a null type
184 /// on failure, after reporting the problem through `emitErrorFn`.
185 static Type
187  function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
188  if (indices.empty()) {
// NOTE(review): this message names spirv.CompositeExtract even when the
// helper is reached while verifying spirv.CompositeInsert.
189  emitErrorFn("expected at least one index for spirv.CompositeExtract");
190  return nullptr;
191  }
192 
193  for (auto index : indices) {
194  if (auto cType = llvm::dyn_cast<spirv::CompositeType>(type)) {
// Bounds are only checked when the composite reports a compile-time-known
// element count; otherwise the index is handed to getElementType unchecked.
195  if (cType.hasCompileTimeKnownNumElements() &&
196  (index < 0 ||
197  static_cast<uint64_t>(index) >= cType.getNumElements())) {
198  emitErrorFn("index ") << index << " out of bounds for " << type;
199  return nullptr;
200  }
201  type = cType.getElementType(index);
202  } else {
203  emitErrorFn("cannot extract from non-composite type ")
204  << type << " with index " << index;
205  return nullptr;
206  }
207  }
208  return type;
209 }
210 
// Attribute-based overload: requires `indices` to be a non-empty ArrayAttr of
// IntegerAttr, converts each entry to int32_t, and defers to the
// ArrayRef<int32_t> overload. Returns a null type on failure, after reporting
// through `emitErrorFn`.
211 static Type
213  function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
214  auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(indices);
215  if (!indicesArrayAttr) {
216  emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
217  return nullptr;
218  }
219  if (indicesArrayAttr.empty()) {
220  emitErrorFn("expected at least one index for spirv.CompositeExtract");
221  return nullptr;
222  }
223 
224  SmallVector<int32_t, 2> indexVals;
225  for (auto indexAttr : indicesArrayAttr) {
226  auto indexIntAttr = llvm::dyn_cast<IntegerAttr>(indexAttr);
227  if (!indexIntAttr) {
228  emitErrorFn("expected an 32-bit integer for index, but found '")
229  << indexAttr << "'";
230  return nullptr;
231  }
// NOTE(review): the integer's type is not checked to be i32; getInt()
// requires a signless (or index) type and its result is narrowed to int32_t.
232  indexVals.push_back(indexIntAttr.getInt());
233  }
234  return getElementType(type, indexVals, emitErrorFn);
235 }
236 
237 static Type getElementType(Type type, Attribute indices, Location loc) {
238  auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
239  return ::mlir::emitError(loc, err);
240  };
241  return getElementType(type, indices, errorFn);
242 }
243 
244 static Type getElementType(Type type, Attribute indices, OpAsmParser &parser,
245  SMLoc loc) {
246  auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
247  return parser.emitError(loc, err);
248  };
249  return getElementType(type, indices, errorFn);
250 }
251 
252 template <typename ExtendedBinaryOp>
253 static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op) {
254  auto resultType = llvm::cast<spirv::StructType>(op.getType());
255  if (resultType.getNumElements() != 2)
256  return op.emitOpError("expected result struct type containing two members");
257 
258  if (!llvm::all_equal({op.getOperand1().getType(), op.getOperand2().getType(),
259  resultType.getElementType(0),
260  resultType.getElementType(1)}))
261  return op.emitOpError(
262  "expected all operand types and struct member types are the same");
263 
264  return success();
265 }
266 
// Parses `{attrs} operands : struct-type`. The result type must be a
// spirv.struct with exactly two members, and every operand is resolved with
// the struct's first member type (resolveOperands rejects any operand count
// other than two).
267 static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser,
268  OperationState &result) {
270  if (parser.parseOptionalAttrDict(result.attributes) ||
271  parser.parseOperandList(operands) || parser.parseColon())
272  return failure();
273 
274  Type resultType;
275  SMLoc loc = parser.getCurrentLocation();
276  if (parser.parseType(resultType))
277  return failure();
278 
279  auto structType = llvm::dyn_cast<spirv::StructType>(resultType);
280  if (!structType || structType.getNumElements() != 2)
281  return parser.emitError(loc, "expected spirv.struct type with two members");
282 
283  SmallVector<Type, 2> operandTypes(2, structType.getElementType(0));
284  if (parser.resolveOperands(operands, operandTypes, loc, result.operands))
285  return failure();
286 
287  result.addTypes(resultType);
288  return success();
289 }
290 
292  OpAsmPrinter &printer) {
// Mirrors parseArithmeticExtendedBinaryOp: attribute dictionary first, then
// the operands, then the first result type after a colon.
293  printer << ' ';
294  printer.printOptionalAttrDict(op->getAttrs());
295  printer.printOperands(op->getOperands());
296  printer << " : " << op->getResultTypes().front();
297 }
298 
299 static LogicalResult verifyShiftOp(Operation *op) {
300  if (op->getOperand(0).getType() != op->getResult(0).getType()) {
301  return op->emitError("expected the same type for the first operand and "
302  "result, but provided ")
303  << op->getOperand(0).getType() << " and "
304  << op->getResult(0).getType();
305  }
306  return success();
307 }
308 
309 //===----------------------------------------------------------------------===//
310 // spirv.mlir.addressof
311 //===----------------------------------------------------------------------===//
312 
313 void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state,
314  spirv::GlobalVariableOp var) {
315  build(builder, state, var.getType(), SymbolRefAttr::get(var));
316 }
317 
318 LogicalResult spirv::AddressOfOp::verify() {
319  auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
320  SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(),
321  getVariableAttr()));
322  if (!varOp) {
323  return emitOpError("expected spirv.GlobalVariable symbol");
324  }
325  if (getPointer().getType() != varOp.getType()) {
326  return emitOpError(
327  "result type mismatch with the referenced global variable's type");
328  }
329  return success();
330 }
331 
332 //===----------------------------------------------------------------------===//
333 // spirv.CompositeConstruct
334 //===----------------------------------------------------------------------===//
335 
// Verifies spirv.CompositeConstruct:
//  - cooperative matrix results take exactly one constituent whose type is
//    the matrix element type;
//  - when the constituent count equals the composite's element count, each
//    constituent must have the corresponding element type;
//  - otherwise the result must be a vector, built from scalars and/or vectors
//    whose (element) types equal the vector element type and whose flattened
//    component count equals the vector length.
336 LogicalResult spirv::CompositeConstructOp::verify() {
337  operand_range constituents = this->getConstituents();
338 
339  // There are 4 cases with varying verification rules:
340  // 1. Cooperative Matrices (1 constituent)
341  // 2. Structs (1 constituent for each member)
342  // 3. Arrays (1 constituent for each array element)
343  // 4. Vectors (1 constituent (sub-)element for each vector element)
344 
345  auto coopElementType =
348  [](auto coopType) { return coopType.getElementType(); })
349  .Default([](Type) { return nullptr; });
350 
351  // Case 1. -- cooperative matrices.
352  if (coopElementType) {
353  if (constituents.size() != 1)
354  return emitOpError("has incorrect number of operands: expected ")
355  << "1, but provided " << constituents.size();
356  if (coopElementType != constituents.front().getType())
357  return emitOpError("operand type mismatch: expected operand type ")
358  << coopElementType << ", but provided "
359  << constituents.front().getType();
360  return success();
361  }
362 
363  // Case 2./3./4. -- number of constituents matches the number of elements.
364  auto cType = llvm::cast<spirv::CompositeType>(getType());
365  if (constituents.size() == cType.getNumElements()) {
366  for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
367  if (constituents[index].getType() != cType.getElementType(index)) {
368  return emitOpError("operand type mismatch: expected operand type ")
369  << cType.getElementType(index) << ", but provided "
370  << constituents[index].getType();
371  }
372  }
373  return success();
374  }
375 
376  // Case 4. -- check that all constituents add up to the expected vector type.
// NOTE(review): the non-vector error below says "less than" but is also
// reached when a struct/array result gets more constituents than elements.
377  auto resultType = llvm::dyn_cast<VectorType>(cType);
378  if (!resultType)
379  return emitOpError(
380  "expected to return a vector or cooperative matrix when the number of "
381  "constituents is less than what the result needs");
382 
// Per-constituent component counts (1 for a scalar, N for a vector of N).
383  SmallVector<unsigned> sizes;
384  for (Value component : constituents) {
385  if (!llvm::isa<VectorType>(component.getType()) &&
386  !component.getType().isIntOrFloat())
387  return emitOpError("operand type mismatch: expected operand to have "
388  "a scalar or vector type, but provided ")
389  << component.getType();
390 
391  Type elementType = component.getType();
392  if (auto vectorType = llvm::dyn_cast<VectorType>(component.getType())) {
393  sizes.push_back(vectorType.getNumElements());
394  elementType = vectorType.getElementType();
395  } else {
396  sizes.push_back(1);
397  }
398 
399  if (elementType != resultType.getElementType())
400  return emitOpError("operand element type mismatch: expected to be ")
401  << resultType.getElementType() << ", but provided " << elementType;
402  }
403  unsigned totalCount = std::accumulate(sizes.begin(), sizes.end(), 0);
404  if (totalCount != cType.getNumElements())
405  return emitOpError("has incorrect number of operands: expected ")
406  << cType.getNumElements() << ", but provided " << totalCount;
407  return success();
408 }
409 
410 //===----------------------------------------------------------------------===//
411 // spirv.CompositeExtractOp
412 //===----------------------------------------------------------------------===//
413 
414 void spirv::CompositeExtractOp::build(OpBuilder &builder, OperationState &state,
415  Value composite,
416  ArrayRef<int32_t> indices) {
417  auto indexAttr = builder.getI32ArrayAttr(indices);
418  auto elementType =
419  getElementType(composite.getType(), indexAttr, state.location);
420  if (!elementType) {
421  return;
422  }
423  build(builder, state, elementType, composite, indexAttr);
424 }
425 
427  OperationState &result) {
// Syntax: `%composite [indices] : composite-type`. The result type is not
// written; it is inferred from the composite type and the indices, with any
// index error reported at the location of the indices attribute.
428  OpAsmParser::UnresolvedOperand compositeInfo;
429  Attribute indicesAttr;
430  StringRef indicesAttrName =
431  spirv::CompositeExtractOp::getIndicesAttrName(result.name);
432  Type compositeType;
433  SMLoc attrLocation;
434 
435  if (parser.parseOperand(compositeInfo) ||
436  parser.getCurrentLocation(&attrLocation) ||
437  parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
438  parser.parseColonType(compositeType) ||
439  parser.resolveOperand(compositeInfo, compositeType, result.operands)) {
440  return failure();
441  }
442 
443  Type resultType =
444  getElementType(compositeType, indicesAttr, parser, attrLocation);
445  if (!resultType) {
446  return failure();
447  }
448  result.addTypes(resultType);
449  return success();
450 }
451 
// Prints the composite operand, the indices array attribute immediately after
// it, and the composite's type; the result type is omitted because the parser
// re-derives it from the indices.
453  printer << ' ' << getComposite() << getIndices() << " : "
454  << getComposite().getType();
455 }
456 
457 LogicalResult spirv::CompositeExtractOp::verify() {
458  auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(getIndices());
459  auto resultType =
460  getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
461  if (!resultType)
462  return failure();
463 
464  if (resultType != getType()) {
465  return emitOpError("invalid result type: expected ")
466  << resultType << " but provided " << getType();
467  }
468 
469  return success();
470 }
471 
472 //===----------------------------------------------------------------------===//
473 // spirv.CompositeInsert
474 //===----------------------------------------------------------------------===//
475 
476 void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state,
477  Value object, Value composite,
478  ArrayRef<int32_t> indices) {
479  auto indexAttr = builder.getI32ArrayAttr(indices);
480  build(builder, state, composite.getType(), object, composite, indexAttr);
481 }
482 
// Parses `%object, %composite [indices] : object-type into composite-type`.
// Exactly two operands are required; the result takes the composite type.
483 ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
484  OperationState &result) {
486  Type objectType, compositeType;
487  Attribute indicesAttr;
488  StringRef indicesAttrName =
489  spirv::CompositeInsertOp::getIndicesAttrName(result.name);
490  auto loc = parser.getCurrentLocation();
491 
492  return failure(
493  parser.parseOperandList(operands, 2) ||
494  parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
495  parser.parseColonType(objectType) ||
496  parser.parseKeywordType("into", compositeType) ||
497  parser.resolveOperands(operands, {objectType, compositeType}, loc,
498  result.operands) ||
499  parser.addTypesToList(compositeType, result.types));
500 }
501 
502 LogicalResult spirv::CompositeInsertOp::verify() {
503  auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(getIndices());
504  auto objectType =
505  getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
506  if (!objectType)
507  return failure();
508 
509  if (objectType != getObject().getType()) {
510  return emitOpError("object operand type should be ")
511  << objectType << ", but found " << getObject().getType();
512  }
513 
514  if (getComposite().getType() != getType()) {
515  return emitOpError("result type should be the same as "
516  "the composite type, but found ")
517  << getComposite().getType() << " vs " << getType();
518  }
519 
520  return success();
521 }
522 
// Prints `%object, %composite` followed directly by the indices attribute,
// then `: object-type into composite-type`, matching the custom parser.
524  printer << " " << getObject() << ", " << getComposite() << getIndices()
525  << " : " << getObject().getType() << " into "
526  << getComposite().getType();
527 }
528 
529 //===----------------------------------------------------------------------===//
530 // spirv.Constant
531 //===----------------------------------------------------------------------===//
532 
533 ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
534  OperationState &result) {
535  Attribute value;
536  StringRef valueAttrName = spirv::ConstantOp::getValueAttrName(result.name);
537  if (parser.parseAttribute(value, valueAttrName, result.attributes))
538  return failure();
539 
540  Type type = NoneType::get(parser.getContext());
541  if (auto typedAttr = llvm::dyn_cast<TypedAttr>(value))
542  type = typedAttr.getType();
543  if (llvm::isa<NoneType, TensorType>(type)) {
544  if (parser.parseColonType(type))
545  return failure();
546  }
547 
548  if (llvm::isa<TensorArmType>(type)) {
549  if (parser.parseOptionalColon().succeeded())
550  if (parser.parseType(type))
551  return failure();
552  }
553 
554  return parser.addTypeToList(type, result.types);
555 }
556 
557 void spirv::ConstantOp::print(OpAsmPrinter &printer) {
558  printer << ' ' << getValue();
559  if (llvm::isa<spirv::ArrayType>(getType()))
560  printer << " : " << getType();
561 }
562 
563 static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
564  Type opType) {
565  if (isa<spirv::CooperativeMatrixType>(opType)) {
566  auto denseAttr = dyn_cast<DenseElementsAttr>(value);
567  if (!denseAttr || !denseAttr.isSplat())
568  return op.emitOpError("expected a splat dense attribute for cooperative "
569  "matrix constant, but found ")
570  << denseAttr;
571  }
572  if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
573  auto valueType = llvm::cast<TypedAttr>(value).getType();
574  if (valueType != opType)
575  return op.emitOpError("result type (")
576  << opType << ") does not match value type (" << valueType << ")";
577  return success();
578  }
579  if (llvm::isa<DenseIntOrFPElementsAttr, SparseElementsAttr>(value)) {
580  auto valueType = llvm::cast<TypedAttr>(value).getType();
581  if (valueType == opType)
582  return success();
583  auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
584  auto shapedType = llvm::dyn_cast<ShapedType>(valueType);
585  if (!arrayType)
586  return op.emitOpError("result or element type (")
587  << opType << ") does not match value type (" << valueType
588  << "), must be the same or spirv.array";
589 
590  int numElements = arrayType.getNumElements();
591  auto opElemType = arrayType.getElementType();
592  while (auto t = llvm::dyn_cast<spirv::ArrayType>(opElemType)) {
593  numElements *= t.getNumElements();
594  opElemType = t.getElementType();
595  }
596  if (!opElemType.isIntOrFloat())
597  return op.emitOpError("only support nested array result type");
598 
599  auto valueElemType = shapedType.getElementType();
600  if (valueElemType != opElemType) {
601  return op.emitOpError("result element type (")
602  << opElemType << ") does not match value element type ("
603  << valueElemType << ")";
604  }
605 
606  if (numElements != shapedType.getNumElements()) {
607  return op.emitOpError("result number of elements (")
608  << numElements << ") does not match value number of elements ("
609  << shapedType.getNumElements() << ")";
610  }
611  return success();
612  }
613  if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(value)) {
614  auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
615  if (!arrayType)
616  return op.emitOpError(
617  "must have spirv.array result type for array value");
618  Type elemType = arrayType.getElementType();
619  for (Attribute element : arrayAttr.getValue()) {
620  // Verify array elements recursively.
621  if (failed(verifyConstantType(op, element, elemType)))
622  return failure();
623  }
624  return success();
625  }
626  return op.emitOpError("cannot have attribute: ") << value;
627 }
628 
629 LogicalResult spirv::ConstantOp::verify() {
630  // ODS already generates checks to make sure the result type is valid. We just
631  // need to additionally check that the value's attribute type is consistent
632  // with the result type.
633  return verifyConstantType(*this, getValueAttr(), getType());
634 }
635 
636 bool spirv::ConstantOp::isBuildableWith(Type type) {
637  // Must be valid SPIR-V type first.
638  if (!llvm::isa<spirv::SPIRVType>(type))
639  return false;
640 
641  if (isa<SPIRVDialect>(type.getDialect())) {
642  // TODO: support constant struct
643  return llvm::isa<spirv::ArrayType>(type);
644  }
645 
646  return true;
647 }
648 
649 spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
650  OpBuilder &builder) {
651  if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
652  unsigned width = intType.getWidth();
653  if (width == 1)
654  return spirv::ConstantOp::create(builder, loc, type,
655  builder.getBoolAttr(false));
656  return spirv::ConstantOp::create(
657  builder, loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
658  }
659  if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
660  return spirv::ConstantOp::create(builder, loc, type,
661  builder.getFloatAttr(floatType, 0.0));
662  }
663  if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
664  Type elemType = vectorType.getElementType();
665  if (llvm::isa<IntegerType>(elemType)) {
666  return spirv::ConstantOp::create(
667  builder, loc, type,
668  DenseElementsAttr::get(vectorType,
669  IntegerAttr::get(elemType, 0).getValue()));
670  }
671  if (llvm::isa<FloatType>(elemType)) {
672  return spirv::ConstantOp::create(
673  builder, loc, type,
674  DenseFPElementsAttr::get(vectorType,
675  FloatAttr::get(elemType, 0.0).getValue()));
676  }
677  }
678 
679  llvm_unreachable("unimplemented types for ConstantOp::getZero()");
680 }
681 
682 spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
683  OpBuilder &builder) {
684  if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
685  unsigned width = intType.getWidth();
686  if (width == 1)
687  return spirv::ConstantOp::create(builder, loc, type,
688  builder.getBoolAttr(true));
689  return spirv::ConstantOp::create(
690  builder, loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
691  }
692  if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
693  return spirv::ConstantOp::create(builder, loc, type,
694  builder.getFloatAttr(floatType, 1.0));
695  }
696  if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
697  Type elemType = vectorType.getElementType();
698  if (llvm::isa<IntegerType>(elemType)) {
699  return spirv::ConstantOp::create(
700  builder, loc, type,
701  DenseElementsAttr::get(vectorType,
702  IntegerAttr::get(elemType, 1).getValue()));
703  }
704  if (llvm::isa<FloatType>(elemType)) {
705  return spirv::ConstantOp::create(
706  builder, loc, type,
707  DenseFPElementsAttr::get(vectorType,
708  FloatAttr::get(elemType, 1.0).getValue()));
709  }
710  }
711 
712  llvm_unreachable("unimplemented types for ConstantOp::getOne()");
713 }
714 
715 void mlir::spirv::ConstantOp::getAsmResultNames(
716  llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
717  Type type = getType();
718 
719  SmallString<32> specialNameBuffer;
720  llvm::raw_svector_ostream specialName(specialNameBuffer);
721  specialName << "cst";
722 
723  IntegerType intTy = llvm::dyn_cast<IntegerType>(type);
724 
725  if (IntegerAttr intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
726  if (intTy && intTy.getWidth() == 1) {
727  return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
728  }
729 
730  if (intTy.isSignless()) {
731  specialName << intCst.getInt();
732  } else if (intTy.isUnsigned()) {
733  specialName << intCst.getUInt();
734  } else {
735  specialName << intCst.getSInt();
736  }
737  }
738 
739  if (intTy || llvm::isa<FloatType>(type)) {
740  specialName << '_' << type;
741  }
742 
743  if (auto vecType = llvm::dyn_cast<VectorType>(type)) {
744  specialName << "_vec_";
745  specialName << vecType.getDimSize(0);
746 
747  Type elementType = vecType.getElementType();
748 
749  if (llvm::isa<IntegerType>(elementType) ||
750  llvm::isa<FloatType>(elementType)) {
751  specialName << "x" << elementType;
752  }
753  }
754 
755  setNameFn(getResult(), specialName.str());
756 }
757 
758 void mlir::spirv::AddressOfOp::getAsmResultNames(
759  llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
760  SmallString<32> specialNameBuffer;
761  llvm::raw_svector_ostream specialName(specialNameBuffer);
762  specialName << getVariable() << "_addr";
763  setNameFn(getResult(), specialName.str());
764 }
765 
766 //===----------------------------------------------------------------------===//
767 // spirv.EXTConstantCompositeReplicate
768 //===----------------------------------------------------------------------===//
769 
770 // Returns the type of an attribute. For a TypedAttr this is simply its type.
771 // An ArrayAttr is untyped and may be nested, so a spirv.array type is built
772 // recursively from its first element. Any other attribute yields a null Type.
// NOTE(review): an empty ArrayAttr makes `arrayAttr[0]` index out of range,
// and only the first element is inspected. A null inner type (from an
// unsupported element attribute) is passed straight to ArrayType::get;
// consider guarding both before building the array type.
774  if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
775  return typedAttr.getType();
776  }
777 
778  if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
779  return spirv::ArrayType::get(getValueType(arrayAttr[0]), arrayAttr.size());
780  }
781 
782  return nullptr;
783 }
784 
// Verifies the replicate-constant op: the result must be a composite type,
// and the value attribute's type must equal the composite's element type or
// the first element type at any deeper level of nested composites (index 0 is
// followed at each level).
786  Type valueType = getValueType(getValue());
787  if (!valueType)
788  return emitError("unknown value attribute type");
789 
790  auto compositeType = dyn_cast<spirv::CompositeType>(getType());
791  if (!compositeType)
792  return emitError("result type is not a composite type");
793 
794  Type compositeElementType = compositeType.getElementType(0);
795 
// Collect the element type at every nesting level as acceptable value types.
796  SmallVector<Type, 3> possibleTypes = {compositeElementType};
797  while (auto type = dyn_cast<spirv::CompositeType>(compositeElementType)) {
798  compositeElementType = type.getElementType(0);
799  possibleTypes.push_back(compositeElementType);
800  }
801 
802  if (!is_contained(possibleTypes, valueType)) {
803  return emitError("expected value attribute type ")
804  << interleaved(possibleTypes, " or ") << ", but got: " << valueType;
805  }
806 
807  return success();
808 }
809 
810 //===----------------------------------------------------------------------===//
811 // spirv.ControlBarrierOp
812 //===----------------------------------------------------------------------===//
813 
814 LogicalResult spirv::ControlBarrierOp::verify() {
815  return verifyMemorySemantics(getOperation(), getMemorySemantics());
816 }
817 
818 //===----------------------------------------------------------------------===//
819 // spirv.EntryPoint
820 //===----------------------------------------------------------------------===//
821 
822 void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state,
823  spirv::ExecutionModel executionModel,
824  spirv::FuncOp function,
825  ArrayRef<Attribute> interfaceVars) {
826  build(builder, state,
827  spirv::ExecutionModelAttr::get(builder.getContext(), executionModel),
828  SymbolRefAttr::get(function), builder.getArrayAttr(interfaceVars));
829 }
830 
// Parses `"<ExecutionModel>" @fn [, @var0, @var1, ...]`. The interface
// variables are collected into an ArrayAttr, which is always added (empty
// when no variables follow).
831 ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
832  OperationState &result) {
833  spirv::ExecutionModel execModel;
834  SmallVector<Attribute, 4> interfaceVars;
835 
837  if (parseEnumStrAttr<spirv::ExecutionModelAttr>(execModel, parser, result) ||
838  parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) {
839  return failure();
840  }
841 
842  if (!parser.parseOptionalComma()) {
843  // Parse the interface variables
844  if (parser.parseCommaSeparatedList([&]() -> ParseResult {
845  // The name of the interface variable attribute isn't important: it is
// parsed into a throwaway NamedAttrList and only the symbol is kept.
846  FlatSymbolRefAttr var;
847  NamedAttrList attrs;
848  if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
849  return failure();
850  interfaceVars.push_back(var);
851  return success();
852  }))
853  return failure();
854  }
855  result.addAttribute(spirv::EntryPointOp::getInterfaceAttrName(result.name),
856  parser.getBuilder().getArrayAttr(interfaceVars));
857  return success();
858 }
859 
// Prints `"<ExecutionModel>" @fn`, followed by `, @var0, @var1, ...` only when
// the interface list is non-empty (the parser treats the list as optional).
861  printer << " \"" << stringifyExecutionModel(getExecutionModel()) << "\" ";
862  printer.printSymbolName(getFn());
863  auto interfaceVars = getInterface().getValue();
864  if (!interfaceVars.empty())
865  printer << ", " << llvm::interleaved(interfaceVars);
866 }
867 
868 LogicalResult spirv::EntryPointOp::verify() {
869  // Checks for fn and interface symbol reference are done in spirv::ModuleOp
870  // verification.
871  return success();
872 }
873 
874 //===----------------------------------------------------------------------===//
875 // spirv.ExecutionMode
876 //===----------------------------------------------------------------------===//
877 
878 void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state,
879  spirv::FuncOp function,
880  spirv::ExecutionMode executionMode,
881  ArrayRef<int32_t> params) {
882  build(builder, state, SymbolRefAttr::get(function),
883  spirv::ExecutionModeAttr::get(builder.getContext(), executionMode),
884  builder.getI32ArrayAttr(params));
885 }
886 
887 ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
888  OperationState &result) {
889  spirv::ExecutionMode execMode;
890  Attribute fn;
891  if (parser.parseAttribute(fn, kFnNameAttrName, result.attributes) ||
892  parseEnumStrAttr<spirv::ExecutionModeAttr>(execMode, parser, result)) {
893  return failure();
894  }
895 
897  Type i32Type = parser.getBuilder().getIntegerType(32);
898  while (!parser.parseOptionalComma()) {
899  NamedAttrList attr;
900  Attribute value;
901  if (parser.parseAttribute(value, i32Type, "value", attr)) {
902  return failure();
903  }
904  values.push_back(llvm::cast<IntegerAttr>(value).getInt());
905  }
906  StringRef valuesAttrName =
907  spirv::ExecutionModeOp::getValuesAttrName(result.name);
908  result.addAttribute(valuesAttrName,
909  parser.getBuilder().getI32ArrayAttr(values));
910  return success();
911 }
912 
914  printer << " ";
915  printer.printSymbolName(getFn());
916  printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\"";
917  ArrayAttr values = this->getValues();
918  if (!values.empty())
919  printer << ", " << llvm::interleaved(values.getAsValueRange<IntegerAttr>());
920 }
921 
922 //===----------------------------------------------------------------------===//
923 // spirv.func
924 //===----------------------------------------------------------------------===//
925 
926 ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
928  SmallVector<DictionaryAttr> resultAttrs;
929  SmallVector<Type> resultTypes;
930  auto &builder = parser.getBuilder();
931 
932  // Parse the name as a symbol.
933  StringAttr nameAttr;
934  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
935  result.attributes))
936  return failure();
937 
938  // Parse the function signature.
939  bool isVariadic = false;
941  parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
942  resultAttrs))
943  return failure();
944 
945  SmallVector<Type> argTypes;
946  for (auto &arg : entryArgs)
947  argTypes.push_back(arg.type);
948  auto fnType = builder.getFunctionType(argTypes, resultTypes);
949  result.addAttribute(getFunctionTypeAttrName(result.name),
950  TypeAttr::get(fnType));
951 
952  // Parse the optional function control keyword.
953  spirv::FunctionControl fnControl;
954  if (parseEnumStrAttr<spirv::FunctionControlAttr>(fnControl, parser, result))
955  return failure();
956 
957  // If additional attributes are present, parse them.
958  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
959  return failure();
960 
961  // Add the attributes to the function arguments.
962  assert(resultAttrs.size() == resultTypes.size());
964  builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
965  getResAttrsAttrName(result.name));
966 
967  // Parse the optional function body.
968  auto *body = result.addRegion();
969  OptionalParseResult parseResult =
970  parser.parseOptionalRegion(*body, entryArgs);
971  return failure(parseResult.has_value() && failed(*parseResult));
972 }
973 
974 void spirv::FuncOp::print(OpAsmPrinter &printer) {
975  // Print function name, signature, and control.
976  printer << " ";
977  printer.printSymbolName(getSymName());
978  auto fnType = getFunctionType();
980  printer, *this, fnType.getInputs(),
981  /*isVariadic=*/false, fnType.getResults());
982  printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl())
983  << "\"";
985  printer, *this,
986  {spirv::attributeName<spirv::FunctionControl>(),
987  getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
988  getFunctionControlAttrName()});
989 
990  // Print the body if this is not an external function.
991  Region &body = this->getBody();
992  if (!body.empty()) {
993  printer << ' ';
994  printer.printRegion(body, /*printEntryBlockArgs=*/false,
995  /*printBlockTerminators=*/true);
996  }
997 }
998 
999 LogicalResult spirv::FuncOp::verifyType() {
1000  FunctionType fnType = getFunctionType();
1001  if (fnType.getNumResults() > 1)
1002  return emitOpError("cannot have more than one result");
1003 
1004  auto hasDecorationAttr = [&](spirv::Decoration decoration,
1005  unsigned argIndex) {
1006  auto func = llvm::cast<FunctionOpInterface>(getOperation());
1007  for (auto argAttr : cast<FunctionOpInterface>(func).getArgAttrs(argIndex)) {
1008  if (argAttr.getName() != spirv::DecorationAttr::name)
1009  continue;
1010  if (auto decAttr = dyn_cast<spirv::DecorationAttr>(argAttr.getValue()))
1011  return decAttr.getValue() == decoration;
1012  }
1013  return false;
1014  };
1015 
1016  for (unsigned i = 0, e = this->getNumArguments(); i != e; ++i) {
1017  Type param = fnType.getInputs()[i];
1018  auto inputPtrType = dyn_cast<spirv::PointerType>(param);
1019  if (!inputPtrType)
1020  continue;
1021 
1022  auto pointeePtrType =
1023  dyn_cast<spirv::PointerType>(inputPtrType.getPointeeType());
1024  if (pointeePtrType) {
1025  // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
1026  // > If an OpFunctionParameter is a pointer (or contains a pointer)
1027  // > and the type it points to is a pointer in the PhysicalStorageBuffer
1028  // > storage class, the function parameter must be decorated with exactly
1029  // > one of AliasedPointer or RestrictPointer.
1030  if (pointeePtrType.getStorageClass() !=
1031  spirv::StorageClass::PhysicalStorageBuffer)
1032  continue;
1033 
1034  bool hasAliasedPtr =
1035  hasDecorationAttr(spirv::Decoration::AliasedPointer, i);
1036  bool hasRestrictPtr =
1037  hasDecorationAttr(spirv::Decoration::RestrictPointer, i);
1038  if (!hasAliasedPtr && !hasRestrictPtr)
1039  return emitOpError()
1040  << "with a pointer points to a physical buffer pointer must "
1041  "be decorated either 'AliasedPointer' or 'RestrictPointer'";
1042  continue;
1043  }
1044  // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
1045  // > If an OpFunctionParameter is a pointer (or contains a pointer) in
1046  // > the PhysicalStorageBuffer storage class, the function parameter must
1047  // > be decorated with exactly one of Aliased or Restrict.
1048  if (auto pointeeArrayType =
1049  dyn_cast<spirv::ArrayType>(inputPtrType.getPointeeType())) {
1050  pointeePtrType =
1051  dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
1052  } else {
1053  pointeePtrType = inputPtrType;
1054  }
1055 
1056  if (!pointeePtrType || pointeePtrType.getStorageClass() !=
1057  spirv::StorageClass::PhysicalStorageBuffer)
1058  continue;
1059 
1060  bool hasAliased = hasDecorationAttr(spirv::Decoration::Aliased, i);
1061  bool hasRestrict = hasDecorationAttr(spirv::Decoration::Restrict, i);
1062  if (!hasAliased && !hasRestrict)
1063  return emitOpError() << "with physical buffer pointer must be decorated "
1064  "either 'Aliased' or 'Restrict'";
1065  }
1066 
1067  return success();
1068 }
1069 
1070 LogicalResult spirv::FuncOp::verifyBody() {
1071  FunctionType fnType = getFunctionType();
1072  if (!isExternal()) {
1073  Block &entryBlock = front();
1074 
1075  unsigned numArguments = this->getNumArguments();
1076  if (entryBlock.getNumArguments() != numArguments)
1077  return emitOpError("entry block must have ")
1078  << numArguments << " arguments to match function signature";
1079 
1080  for (auto [index, fnArgType, blockArgType] :
1081  llvm::enumerate(getArgumentTypes(), entryBlock.getArgumentTypes())) {
1082  if (blockArgType != fnArgType) {
1083  return emitOpError("type of entry block argument #")
1084  << index << '(' << blockArgType
1085  << ") must match the type of the corresponding argument in "
1086  << "function signature(" << fnArgType << ')';
1087  }
1088  }
1089  }
1090 
1091  auto walkResult = walk([fnType](Operation *op) -> WalkResult {
1092  if (auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
1093  if (fnType.getNumResults() != 0)
1094  return retOp.emitOpError("cannot be used in functions returning value");
1095  } else if (auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
1096  if (fnType.getNumResults() != 1)
1097  return retOp.emitOpError(
1098  "returns 1 value but enclosing function requires ")
1099  << fnType.getNumResults() << " results";
1100 
1101  auto retOperandType = retOp.getValue().getType();
1102  auto fnResultType = fnType.getResult(0);
1103  if (retOperandType != fnResultType)
1104  return retOp.emitOpError(" return value's type (")
1105  << retOperandType << ") mismatch with function's result type ("
1106  << fnResultType << ")";
1107  }
1108  return WalkResult::advance();
1109  });
1110 
1111  // TODO: verify other bits like linkage type.
1112 
1113  return failure(walkResult.wasInterrupted());
1114 }
1115 
1116 void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
1117  StringRef name, FunctionType type,
1118  spirv::FunctionControl control,
1119  ArrayRef<NamedAttribute> attrs) {
1120  state.addAttribute(SymbolTable::getSymbolAttrName(),
1121  builder.getStringAttr(name));
1122  state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
1123  state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
1124  builder.getAttr<spirv::FunctionControlAttr>(control));
1125  state.attributes.append(attrs.begin(), attrs.end());
1126  state.addRegion();
1127 }
1128 
1129 //===----------------------------------------------------------------------===//
1130 // spirv.GLFClampOp
1131 //===----------------------------------------------------------------------===//
1132 
1133 ParseResult spirv::GLFClampOp::parse(OpAsmParser &parser,
1134  OperationState &result) {
1135  return parseOneResultSameOperandTypeOp(parser, result);
1136 }
1138 
1139 //===----------------------------------------------------------------------===//
1140 // spirv.GLUClampOp
1141 //===----------------------------------------------------------------------===//
1142 
1143 ParseResult spirv::GLUClampOp::parse(OpAsmParser &parser,
1144  OperationState &result) {
1145  return parseOneResultSameOperandTypeOp(parser, result);
1146 }
1148 
1149 //===----------------------------------------------------------------------===//
1150 // spirv.GLSClampOp
1151 //===----------------------------------------------------------------------===//
1152 
1153 ParseResult spirv::GLSClampOp::parse(OpAsmParser &parser,
1154  OperationState &result) {
1155  return parseOneResultSameOperandTypeOp(parser, result);
1156 }
1158 
1159 //===----------------------------------------------------------------------===//
1160 // spirv.GLFmaOp
1161 //===----------------------------------------------------------------------===//
1162 
1163 ParseResult spirv::GLFmaOp::parse(OpAsmParser &parser, OperationState &result) {
1164  return parseOneResultSameOperandTypeOp(parser, result);
1165 }
1166 void spirv::GLFmaOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1167 
1168 //===----------------------------------------------------------------------===//
1169 // spirv.GlobalVariable
1170 //===----------------------------------------------------------------------===//
1171 
1172 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
1173  Type type, StringRef name,
1174  unsigned descriptorSet, unsigned binding) {
1175  build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
1176  state.addAttribute(
1177  spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
1178  builder.getI32IntegerAttr(descriptorSet));
1179  state.addAttribute(
1180  spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
1181  builder.getI32IntegerAttr(binding));
1182 }
1183 
1184 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
1185  Type type, StringRef name,
1186  spirv::BuiltIn builtin) {
1187  build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
1188  state.addAttribute(
1189  spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
1190  builder.getStringAttr(spirv::stringifyBuiltIn(builtin)));
1191 }
1192 
1193 ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
1194  OperationState &result) {
1195  // Parse variable name.
1196  StringAttr nameAttr;
1197  StringRef initializerAttrName =
1198  spirv::GlobalVariableOp::getInitializerAttrName(result.name);
1199  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1200  result.attributes)) {
1201  return failure();
1202  }
1203 
1204  // Parse optional initializer
1205  if (succeeded(parser.parseOptionalKeyword(initializerAttrName))) {
1206  FlatSymbolRefAttr initSymbol;
1207  if (parser.parseLParen() ||
1208  parser.parseAttribute(initSymbol, Type(), initializerAttrName,
1209  result.attributes) ||
1210  parser.parseRParen())
1211  return failure();
1212  }
1213 
1214  if (parseVariableDecorations(parser, result)) {
1215  return failure();
1216  }
1217 
1218  Type type;
1219  StringRef typeAttrName =
1220  spirv::GlobalVariableOp::getTypeAttrName(result.name);
1221  auto loc = parser.getCurrentLocation();
1222  if (parser.parseColonType(type)) {
1223  return failure();
1224  }
1225  if (!llvm::isa<spirv::PointerType>(type)) {
1226  return parser.emitError(loc, "expected spirv.ptr type");
1227  }
1228  result.addAttribute(typeAttrName, TypeAttr::get(type));
1229 
1230  return success();
1231 }
1232 
1234  SmallVector<StringRef, 4> elidedAttrs{
1235  spirv::attributeName<spirv::StorageClass>()};
1236 
1237  // Print variable name.
1238  printer << ' ';
1239  printer.printSymbolName(getSymName());
1240  elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
1241 
1242  StringRef initializerAttrName = this->getInitializerAttrName();
1243  // Print optional initializer
1244  if (auto initializer = this->getInitializer()) {
1245  printer << " " << initializerAttrName << '(';
1246  printer.printSymbolName(*initializer);
1247  printer << ')';
1248  elidedAttrs.push_back(initializerAttrName);
1249  }
1250 
1251  StringRef typeAttrName = this->getTypeAttrName();
1252  elidedAttrs.push_back(typeAttrName);
1253  spirv::printVariableDecorations(*this, printer, elidedAttrs);
1254  printer << " : " << getType();
1255 }
1256 
1257 LogicalResult spirv::GlobalVariableOp::verify() {
1258  if (!llvm::isa<spirv::PointerType>(getType()))
1259  return emitOpError("result must be of a !spv.ptr type");
1260 
1261  // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
1262  // object. It cannot be Generic. It must be the same as the Storage Class
1263  // operand of the Result Type."
1264  // Also, Function storage class is reserved by spirv.Variable.
1265  auto storageClass = this->storageClass();
1266  if (storageClass == spirv::StorageClass::Generic ||
1267  storageClass == spirv::StorageClass::Function) {
1268  return emitOpError("storage class cannot be '")
1269  << stringifyStorageClass(storageClass) << "'";
1270  }
1271 
1272  if (auto init = (*this)->getAttrOfType<FlatSymbolRefAttr>(
1273  this->getInitializerAttrName())) {
1275  (*this)->getParentOp(), init.getAttr());
1276  // TODO: Currently only variable initialization with specialization
1277  // constants and other variables is supported. They could be normal
1278  // constants in the module scope as well.
1279  if (!initOp || !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp,
1280  spirv::SpecConstantCompositeOp>(initOp)) {
1281  return emitOpError("initializer must be result of a "
1282  "spirv.SpecConstant or spirv.GlobalVariable or "
1283  "spirv.SpecConstantCompositeOp op");
1284  }
1285  }
1286 
1287  return success();
1288 }
1289 
1290 //===----------------------------------------------------------------------===//
1291 // spirv.INTEL.SubgroupBlockRead
1292 //===----------------------------------------------------------------------===//
1293 
1295  if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
1296  return failure();
1297 
1298  return success();
1299 }
1300 
1301 //===----------------------------------------------------------------------===//
1302 // spirv.INTEL.SubgroupBlockWrite
1303 //===----------------------------------------------------------------------===//
1304 
1306  OperationState &result) {
1307  // Parse the storage class specification
1308  spirv::StorageClass storageClass;
1310  auto loc = parser.getCurrentLocation();
1311  Type elementType;
1312  if (parseEnumStrAttr(storageClass, parser) ||
1313  parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
1314  parser.parseType(elementType)) {
1315  return failure();
1316  }
1317 
1318  auto ptrType = spirv::PointerType::get(elementType, storageClass);
1319  if (auto valVecTy = llvm::dyn_cast<VectorType>(elementType))
1320  ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
1321 
1322  if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
1323  result.operands)) {
1324  return failure();
1325  }
1326  return success();
1327 }
1328 
1330  printer << " " << getPtr() << ", " << getValue() << " : "
1331  << getValue().getType();
1332 }
1333 
1335  if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
1336  return failure();
1337 
1338  return success();
1339 }
1340 
1341 //===----------------------------------------------------------------------===//
1342 // spirv.IAddCarryOp
1343 //===----------------------------------------------------------------------===//
1344 
1345 LogicalResult spirv::IAddCarryOp::verify() {
1347 }
1348 
1349 ParseResult spirv::IAddCarryOp::parse(OpAsmParser &parser,
1350  OperationState &result) {
1352 }
1353 
1354 void spirv::IAddCarryOp::print(OpAsmPrinter &printer) {
1355  ::printArithmeticExtendedBinaryOp(*this, printer);
1356 }
1357 
1358 //===----------------------------------------------------------------------===//
1359 // spirv.ISubBorrowOp
1360 //===----------------------------------------------------------------------===//
1361 
1362 LogicalResult spirv::ISubBorrowOp::verify() {
1364 }
1365 
1366 ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser,
1367  OperationState &result) {
1369 }
1370 
1372  ::printArithmeticExtendedBinaryOp(*this, printer);
1373 }
1374 
1375 //===----------------------------------------------------------------------===//
1376 // spirv.SMulExtended
1377 //===----------------------------------------------------------------------===//
1378 
1379 LogicalResult spirv::SMulExtendedOp::verify() {
1381 }
1382 
1383 ParseResult spirv::SMulExtendedOp::parse(OpAsmParser &parser,
1384  OperationState &result) {
1386 }
1387 
1389  ::printArithmeticExtendedBinaryOp(*this, printer);
1390 }
1391 
1392 //===----------------------------------------------------------------------===//
1393 // spirv.UMulExtended
1394 //===----------------------------------------------------------------------===//
1395 
1396 LogicalResult spirv::UMulExtendedOp::verify() {
1398 }
1399 
1400 ParseResult spirv::UMulExtendedOp::parse(OpAsmParser &parser,
1401  OperationState &result) {
1403 }
1404 
1406  ::printArithmeticExtendedBinaryOp(*this, printer);
1407 }
1408 
1409 //===----------------------------------------------------------------------===//
1410 // spirv.MemoryBarrierOp
1411 //===----------------------------------------------------------------------===//
1412 
1413 LogicalResult spirv::MemoryBarrierOp::verify() {
1414  return verifyMemorySemantics(getOperation(), getMemorySemantics());
1415 }
1416 
1417 //===----------------------------------------------------------------------===//
1418 // spirv.module
1419 //===----------------------------------------------------------------------===//
1420 
1421 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
1422  std::optional<StringRef> name) {
1423  OpBuilder::InsertionGuard guard(builder);
1424  builder.createBlock(state.addRegion());
1425  if (name) {
1426  state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
1427  builder.getStringAttr(*name));
1428  }
1429 }
1430 
1431 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
1432  spirv::AddressingModel addressingModel,
1433  spirv::MemoryModel memoryModel,
1434  std::optional<VerCapExtAttr> vceTriple,
1435  std::optional<StringRef> name) {
1436  state.addAttribute(
1437  "addressing_model",
1438  builder.getAttr<spirv::AddressingModelAttr>(addressingModel));
1439  state.addAttribute("memory_model",
1440  builder.getAttr<spirv::MemoryModelAttr>(memoryModel));
1441  OpBuilder::InsertionGuard guard(builder);
1442  builder.createBlock(state.addRegion());
1443  if (vceTriple)
1444  state.addAttribute(getVCETripleAttrName(), *vceTriple);
1445  if (name)
1446  state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
1447  builder.getStringAttr(*name));
1448 }
1449 
1450 ParseResult spirv::ModuleOp::parse(OpAsmParser &parser,
1451  OperationState &result) {
1452  Region *body = result.addRegion();
1453 
1454  // If the name is present, parse it.
1455  StringAttr nameAttr;
1456  (void)parser.parseOptionalSymbolName(
1457  nameAttr, mlir::SymbolTable::getSymbolAttrName(), result.attributes);
1458 
1459  // Parse attributes
1460  spirv::AddressingModel addrModel;
1461  spirv::MemoryModel memoryModel;
1462  if (spirv::parseEnumKeywordAttr<spirv::AddressingModelAttr>(addrModel, parser,
1463  result) ||
1464  spirv::parseEnumKeywordAttr<spirv::MemoryModelAttr>(memoryModel, parser,
1465  result))
1466  return failure();
1467 
1468  if (succeeded(parser.parseOptionalKeyword("requires"))) {
1469  spirv::VerCapExtAttr vceTriple;
1470  if (parser.parseAttribute(vceTriple,
1471  spirv::ModuleOp::getVCETripleAttrName(),
1472  result.attributes))
1473  return failure();
1474  }
1475 
1476  if (parser.parseOptionalAttrDictWithKeyword(result.attributes) ||
1477  parser.parseRegion(*body, /*arguments=*/{}))
1478  return failure();
1479 
1480  // Make sure we have at least one block.
1481  if (body->empty())
1482  body->push_back(new Block());
1483 
1484  return success();
1485 }
1486 
1487 void spirv::ModuleOp::print(OpAsmPrinter &printer) {
1488  if (std::optional<StringRef> name = getName()) {
1489  printer << ' ';
1490  printer.printSymbolName(*name);
1491  }
1492 
1493  SmallVector<StringRef, 2> elidedAttrs;
1494 
1495  printer << " " << spirv::stringifyAddressingModel(getAddressingModel()) << " "
1496  << spirv::stringifyMemoryModel(getMemoryModel());
1497  auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
1498  auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
1499  elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
1501 
1502  if (std::optional<spirv::VerCapExtAttr> triple = getVceTriple()) {
1503  printer << " requires " << *triple;
1504  elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
1505  }
1506 
1507  printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs);
1508  printer << ' ';
1509  printer.printRegion(getRegion());
1510 }
1511 
1512 LogicalResult spirv::ModuleOp::verifyRegions() {
1513  Dialect *dialect = (*this)->getDialect();
1515  entryPoints;
1516  mlir::SymbolTable table(*this);
1517 
1518  for (auto &op : *getBody()) {
1519  if (op.getDialect() != dialect)
1520  return op.emitError("'spirv.module' can only contain spirv.* ops");
1521 
1522  // For EntryPoint op, check that the function and execution model is not
1523  // duplicated in EntryPointOps. Also verify that the interface specified
1524  // comes from globalVariables here to make this check cheaper.
1525  if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
1526  auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.getFn());
1527  if (!funcOp) {
1528  return entryPointOp.emitError("function '")
1529  << entryPointOp.getFn() << "' not found in 'spirv.module'";
1530  }
1531  if (auto interface = entryPointOp.getInterface()) {
1532  for (Attribute varRef : interface) {
1533  auto varSymRef = llvm::dyn_cast<FlatSymbolRefAttr>(varRef);
1534  if (!varSymRef) {
1535  return entryPointOp.emitError(
1536  "expected symbol reference for interface "
1537  "specification instead of '")
1538  << varRef;
1539  }
1540  auto variableOp =
1541  table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
1542  if (!variableOp) {
1543  return entryPointOp.emitError("expected spirv.GlobalVariable "
1544  "symbol reference instead of'")
1545  << varSymRef << "'";
1546  }
1547  }
1548  }
1549 
1550  auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
1551  funcOp, entryPointOp.getExecutionModel());
1552  if (!entryPoints.try_emplace(key, entryPointOp).second)
1553  return entryPointOp.emitError("duplicate of a previous EntryPointOp");
1554  } else if (auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
1555  // If the function is external and does not have 'Import'
1556  // linkage_attributes(LinkageAttributes), throw an error. 'Import'
1557  // LinkageAttributes is used to import external functions.
1558  auto linkageAttr = funcOp.getLinkageAttributes();
1559  auto hasImportLinkage =
1560  linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
1561  spirv::LinkageType::Import);
1562  if (funcOp.isExternal() && !hasImportLinkage)
1563  return op.emitError(
1564  "'spirv.module' cannot contain external functions "
1565  "without 'Import' linkage_attributes (LinkageAttributes)");
1566 
1567  // TODO: move this check to spirv.func.
1568  for (auto &block : funcOp)
1569  for (auto &op : block) {
1570  if (op.getDialect() != dialect)
1571  return op.emitError(
1572  "functions in 'spirv.module' can only contain spirv.* ops");
1573  }
1574  }
1575  }
1576 
1577  return success();
1578 }
1579 
1580 //===----------------------------------------------------------------------===//
1581 // spirv.mlir.referenceof
1582 //===----------------------------------------------------------------------===//
1583 
1584 LogicalResult spirv::ReferenceOfOp::verify() {
1585  auto *specConstSym = SymbolTable::lookupNearestSymbolFrom(
1586  (*this)->getParentOp(), getSpecConstAttr());
1587  Type constType;
1588 
1589  auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
1590  if (specConstOp)
1591  constType = specConstOp.getDefaultValue().getType();
1592 
1593  auto specConstCompositeOp =
1594  dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
1595  if (specConstCompositeOp)
1596  constType = specConstCompositeOp.getType();
1597 
1598  if (!specConstOp && !specConstCompositeOp)
1599  return emitOpError(
1600  "expected spirv.SpecConstant or spirv.SpecConstantComposite symbol");
1601 
1602  if (getReference().getType() != constType)
1603  return emitOpError("result type mismatch with the referenced "
1604  "specialization constant's type");
1605 
1606  return success();
1607 }
1608 
1609 //===----------------------------------------------------------------------===//
1610 // spirv.SpecConstant
1611 //===----------------------------------------------------------------------===//
1612 
1613 ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
1614  OperationState &result) {
1615  StringAttr nameAttr;
1616  Attribute valueAttr;
1617  StringRef defaultValueAttrName =
1618  spirv::SpecConstantOp::getDefaultValueAttrName(result.name);
1619 
1620  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1621  result.attributes))
1622  return failure();
1623 
1624  // Parse optional spec_id.
1625  if (succeeded(parser.parseOptionalKeyword(kSpecIdAttrName))) {
1626  IntegerAttr specIdAttr;
1627  if (parser.parseLParen() ||
1628  parser.parseAttribute(specIdAttr, kSpecIdAttrName, result.attributes) ||
1629  parser.parseRParen())
1630  return failure();
1631  }
1632 
1633  if (parser.parseEqual() ||
1634  parser.parseAttribute(valueAttr, defaultValueAttrName, result.attributes))
1635  return failure();
1636 
1637  return success();
1638 }
1639 
1641  printer << ' ';
1642  printer.printSymbolName(getSymName());
1643  if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
1644  printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')';
1645  printer << " = " << getDefaultValue();
1646 }
1647 
1648 LogicalResult spirv::SpecConstantOp::verify() {
1649  if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
1650  if (specID.getValue().isNegative())
1651  return emitOpError("SpecId cannot be negative");
1652 
1653  auto value = getDefaultValue();
1654  if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
1655  // Make sure bitwidth is allowed.
1656  if (!llvm::isa<spirv::SPIRVType>(value.getType()))
1657  return emitOpError("default value bitwidth disallowed");
1658  return success();
1659  }
1660  return emitOpError(
1661  "default value can only be a bool, integer, or float scalar");
1662 }
1663 
1664 //===----------------------------------------------------------------------===//
1665 // spirv.VectorShuffle
1666 //===----------------------------------------------------------------------===//
1667 
1668 LogicalResult spirv::VectorShuffleOp::verify() {
1669  VectorType resultType = llvm::cast<VectorType>(getType());
1670 
1671  size_t numResultElements = resultType.getNumElements();
1672  if (numResultElements != getComponents().size())
1673  return emitOpError("result type element count (")
1674  << numResultElements
1675  << ") mismatch with the number of component selectors ("
1676  << getComponents().size() << ")";
1677 
1678  size_t totalSrcElements =
1679  llvm::cast<VectorType>(getVector1().getType()).getNumElements() +
1680  llvm::cast<VectorType>(getVector2().getType()).getNumElements();
1681 
1682  for (const auto &selector : getComponents().getAsValueRange<IntegerAttr>()) {
1683  uint32_t index = selector.getZExtValue();
1684  if (index >= totalSrcElements &&
1685  index != std::numeric_limits<uint32_t>().max())
1686  return emitOpError("component selector ")
1687  << index << " out of range: expected to be in [0, "
1688  << totalSrcElements << ") or 0xffffffff";
1689  }
1690  return success();
1691 }
1692 
1693 //===----------------------------------------------------------------------===//
1694 // spirv.MatrixTimesScalar
1695 //===----------------------------------------------------------------------===//
1696 
1697 LogicalResult spirv::MatrixTimesScalarOp::verify() {
1698  Type elementType =
1699  llvm::TypeSwitch<Type, Type>(getMatrix().getType())
1701  [](auto matrixType) { return matrixType.getElementType(); })
1702  .Default([](Type) { return nullptr; });
1703 
1704  assert(elementType && "Unhandled type");
1705 
1706  // Check that the scalar type is the same as the matrix element type.
1707  if (getScalar().getType() != elementType)
1708  return emitOpError("input matrix components' type and scaling value must "
1709  "have the same type");
1710 
1711  return success();
1712 }
1713 
1714 //===----------------------------------------------------------------------===//
1715 // spirv.Transpose
1716 //===----------------------------------------------------------------------===//
1717 
1718 LogicalResult spirv::TransposeOp::verify() {
1719  auto inputMatrix = llvm::cast<spirv::MatrixType>(getMatrix().getType());
1720  auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().getType());
1721 
1722  // Verify that the input and output matrices have correct shapes.
1723  if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
1724  return emitError("input matrix rows count must be equal to "
1725  "output matrix columns count");
1726 
1727  if (inputMatrix.getNumColumns() != resultMatrix.getNumRows())
1728  return emitError("input matrix columns count must be equal to "
1729  "output matrix rows count");
1730 
1731  // Verify that the input and output matrices have the same component type
1732  if (inputMatrix.getElementType() != resultMatrix.getElementType())
1733  return emitError("input and output matrices must have the same "
1734  "component type");
1735 
1736  return success();
1737 }
1738 
1739 //===----------------------------------------------------------------------===//
1740 // spirv.MatrixTimesVector
1741 //===----------------------------------------------------------------------===//
1742 
1743 LogicalResult spirv::MatrixTimesVectorOp::verify() {
1744  auto matrixType = llvm::cast<spirv::MatrixType>(getMatrix().getType());
1745  auto vectorType = llvm::cast<VectorType>(getVector().getType());
1746  auto resultType = llvm::cast<VectorType>(getType());
1747 
1748  if (matrixType.getNumColumns() != vectorType.getNumElements())
1749  return emitOpError("matrix columns (")
1750  << matrixType.getNumColumns() << ") must match vector operand size ("
1751  << vectorType.getNumElements() << ")";
1752 
1753  if (resultType.getNumElements() != matrixType.getNumRows())
1754  return emitOpError("result size (")
1755  << resultType.getNumElements() << ") must match the matrix rows ("
1756  << matrixType.getNumRows() << ")";
1757 
1758  if (matrixType.getElementType() != resultType.getElementType())
1759  return emitOpError("matrix and result element types must match");
1760 
1761  return success();
1762 }
1763 
1764 //===----------------------------------------------------------------------===//
1765 // spirv.VectorTimesMatrix
1766 //===----------------------------------------------------------------------===//
1767 
1768 LogicalResult spirv::VectorTimesMatrixOp::verify() {
1769  auto vectorType = llvm::cast<VectorType>(getVector().getType());
1770  auto matrixType = llvm::cast<spirv::MatrixType>(getMatrix().getType());
1771  auto resultType = llvm::cast<VectorType>(getType());
1772 
1773  if (matrixType.getNumRows() != vectorType.getNumElements())
1774  return emitOpError("number of components in vector must equal the number "
1775  "of components in each column in matrix");
1776 
1777  if (resultType.getNumElements() != matrixType.getNumColumns())
1778  return emitOpError("number of columns in matrix must equal the number of "
1779  "components in result");
1780 
1781  if (matrixType.getElementType() != resultType.getElementType())
1782  return emitOpError("matrix must be a matrix with the same component type "
1783  "as the component type in result");
1784 
1785  return success();
1786 }
1787 
1788 //===----------------------------------------------------------------------===//
1789 // spirv.MatrixTimesMatrix
1790 //===----------------------------------------------------------------------===//
1791 
1792 LogicalResult spirv::MatrixTimesMatrixOp::verify() {
1793  auto leftMatrix = llvm::cast<spirv::MatrixType>(getLeftmatrix().getType());
1794  auto rightMatrix = llvm::cast<spirv::MatrixType>(getRightmatrix().getType());
1795  auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().getType());
1796 
1797  // left matrix columns' count and right matrix rows' count must be equal
1798  if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
1799  return emitError("left matrix columns' count must be equal to "
1800  "the right matrix rows' count");
1801 
1802  // right and result matrices columns' count must be the same
1803  if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns())
1804  return emitError(
1805  "right and result matrices must have equal columns' count");
1806 
1807  // right and result matrices component type must be the same
1808  if (rightMatrix.getElementType() != resultMatrix.getElementType())
1809  return emitError("right and result matrices' component type must"
1810  " be the same");
1811 
1812  // left and result matrices component type must be the same
1813  if (leftMatrix.getElementType() != resultMatrix.getElementType())
1814  return emitError("left and result matrices' component type"
1815  " must be the same");
1816 
1817  // left and result matrices rows count must be the same
1818  if (leftMatrix.getNumRows() != resultMatrix.getNumRows())
1819  return emitError("left and result matrices must have equal rows' count");
1820 
1821  return success();
1822 }
1823 
1824 //===----------------------------------------------------------------------===//
1825 // spirv.SpecConstantComposite
1826 //===----------------------------------------------------------------------===//
1827 
1829  OperationState &result) {
// Parses the custom form `@sym_name ( @c0, @c1, ... ) : type`:
//   - the symbol name goes into the symbol-name attribute;
//   - the constituent symbol references go into an ArrayAttr;
//   - the trailing type goes into a TypeAttr.
1830 
1831  StringAttr compositeName;
1832  if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
1833  result.attributes))
1834  return failure();
1835 
1836  if (parser.parseLParen())
1837  return failure();
1838 
1839  SmallVector<Attribute, 4> constituents;
1840 
// The do/while requires at least one constituent between the parentheses.
1841  do {
1842  // The attribute name only labels the scratch `attrs` list, which is discarded, so any name works.
1843  const char *attrName = "spec_const";
1844  FlatSymbolRefAttr specConstRef;
1845  NamedAttrList attrs;
1846 
1847  if (parser.parseAttribute(specConstRef, Type(), attrName, attrs))
1848  return failure();
1849 
1850  constituents.push_back(specConstRef);
1851  } while (!parser.parseOptionalComma());
1852 
1853  if (parser.parseRParen())
1854  return failure();
1855 
1856  StringAttr compositeSpecConstituentsName =
1857  spirv::SpecConstantCompositeOp::getConstituentsAttrName(result.name);
1858  result.addAttribute(compositeSpecConstituentsName,
1859  parser.getBuilder().getArrayAttr(constituents));
1860 
1861  Type type;
1862  if (parser.parseColonType(type))
1863  return failure();
1864 
1865  StringAttr typeAttrName =
1866  spirv::SpecConstantCompositeOp::getTypeAttrName(result.name);
1867  result.addAttribute(typeAttrName, TypeAttr::get(type));
1868 
1869  return success();
1870 }
1871 
// Prints ` @sym_name (<constituents>) : <type>`, which mirrors the custom form
// accepted by the parser above.
1873  printer << " ";
1874  printer.printSymbolName(getSymName());
1875  printer << " (" << llvm::interleaved(this->getConstituents().getValue())
1876  << ") : " << getType();
1877 }
1878 
// Verifies spirv.SpecConstantComposite:
//   - the result type is a composite type other than a cooperative matrix;
//   - it has one constituent per composite element;
//   - each referenced spec constant's default value has the matching
//     element type.
1880  auto cType = llvm::dyn_cast<spirv::CompositeType>(getType());
1881  auto constituents = this->getConstituents().getValue();
1882 
1883  if (!cType)
1884  return emitError("result type must be a composite type, but provided ")
1885  << getType();
1886 
1887  if (llvm::isa<spirv::CooperativeMatrixType>(cType))
1888  return emitError("unsupported composite type ") << cType;
1889  if (constituents.size() != cType.getNumElements())
1890  return emitError("has incorrect number of operands: expected ")
1891  << cType.getNumElements() << ", but provided "
1892  << constituents.size();
1893 
1894  for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
1895  auto constituent = llvm::cast<FlatSymbolRefAttr>(constituents[index]);
1896 
// NOTE(review): there is no null check after the lookup/dyn_cast below. If
// the symbol is not found, or it names something other than a
// spirv.SpecConstant, the getDefaultValue() call dereferences a null op.
// The EXT replicate verifier diagnoses both cases; consider doing the same.
1897  auto constituentSpecConstOp =
1898  dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom(
1899  (*this)->getParentOp(), constituent.getAttr()));
1900 
1901  if (constituentSpecConstOp.getDefaultValue().getType() !=
1902  cType.getElementType(index))
1903  return emitError("has incorrect types of operands: expected ")
1904  << cType.getElementType(index) << ", but provided "
1905  << constituentSpecConstOp.getDefaultValue().getType();
1906  }
1907 
1908  return success();
1909 }
1910 
1911 //===----------------------------------------------------------------------===//
1912 // spirv.EXTSpecConstantCompositeReplicateOp
1913 //===----------------------------------------------------------------------===//
1914 
// Parses the custom form `@sym_name ( @constituent ) : type`. It records the
// single constituent symbol reference and the result type as attributes.
1915 ParseResult
1917  OperationState &result) {
1918  StringAttr compositeName;
1919  FlatSymbolRefAttr specConstRef;
// The attribute name only labels the scratch `attrs` list, which is discarded.
1920  const char *attrName = "spec_const";
1921  NamedAttrList attrs;
1922  Type type;
1923 
1924  if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
1925  result.attributes) ||
1926  parser.parseLParen() ||
1927  parser.parseAttribute(specConstRef, Type(), attrName, attrs) ||
1928  parser.parseRParen() || parser.parseColonType(type))
1929  return failure();
1930 
1931  StringAttr compositeSpecConstituentName =
1932  spirv::EXTSpecConstantCompositeReplicateOp::getConstituentAttrName(
1933  result.name);
1934  result.addAttribute(compositeSpecConstituentName, specConstRef);
1935 
1936  StringAttr typeAttrName =
1937  spirv::EXTSpecConstantCompositeReplicateOp::getTypeAttrName(result.name);
1938  result.addAttribute(typeAttrName, TypeAttr::get(type));
1939 
1940  return success();
1941 }
1942 
// Prints ` @sym_name (<constituent>) : <type>`, which mirrors the parser above.
1944  printer << " ";
1945  printer.printSymbolName(getSymName());
1946  printer << " (" << this->getConstituent() << ") : " << getType();
1947 }
1948 
// Verifies the replicated spec-constant composite:
//   - the result type must be a composite;
//   - the constituent symbol must resolve to a spirv.SpecConstant;
//   - that spec constant's default-value type must match the composite's
//     element type.
1950  auto compositeType = dyn_cast<spirv::CompositeType>(getType());
1951  if (!compositeType)
1952  return emitError("result type must be a composite type, but provided ")
1953  << getType();
1954 
1956  (*this)->getParentOp(), this->getConstituent());
1957  if (!constituentOp)
1958  return emitError(
1959  "splat spec constant reference defining constituent not found");
1960 
1961  auto constituentSpecConstOp = dyn_cast<spirv::SpecConstantOp>(constituentOp);
1962  if (!constituentSpecConstOp)
1963  return emitError("constituent is not a spec constant");
1964 
// Only element 0 is compared. NOTE(review): this presumes every element of the
// composite has the same type; a struct with heterogeneous members would only
// have its first member checked — TODO confirm intended.
1965  Type constituentType = constituentSpecConstOp.getDefaultValue().getType();
1966  Type compositeElementType = compositeType.getElementType(0);
1967  if (constituentType != compositeElementType)
1968  return emitError("constituent has incorrect type: expected ")
1969  << compositeElementType << ", but provided " << constituentType;
1970 
1971  return success();
1972 }
1973 
1974 //===----------------------------------------------------------------------===//
1975 // spirv.SpecConstantOperation
1976 //===----------------------------------------------------------------------===//
1977 
1979  OperationState &result) {
// Parses `wraps <generic-op> attr-dict`:
//   - the wrapped op is parsed in generic form into a fresh block;
//   - a spirv.Yield of its first result is appended to that block;
//   - that result's type becomes this op's result type.
1980  Region *body = result.addRegion();
1981 
1982  if (parser.parseKeyword("wraps"))
1983  return failure();
1984 
1985  body->push_back(new Block);
1986  Block &block = body->back();
1987  Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
1988 
1989  if (!wrappedOp)
1990  return failure();
1991 
// NOTE(review): getResult(0) assumes the wrapped op has at least one result.
// Parsing runs before verification, so a zero-result op would fail here
// rather than be diagnosed — TODO confirm.
1992  OpBuilder builder(parser.getContext());
1993  builder.setInsertionPointToEnd(&block);
1994  spirv::YieldOp::create(builder, wrappedOp->getLoc(), wrappedOp->getResult(0));
// This op takes its location from the wrapped op.
1995  result.location = wrappedOp->getLoc();
1996 
1997  result.addTypes(wrappedOp->getResult(0).getType());
1998 
1999  if (parser.parseOptionalAttrDict(result.attributes))
2000  return failure();
2001 
2002  return success();
2003 }
2004 
// Prints ` wraps ` followed by the first op of the body block (the wrapped
// op) in generic form. The trailing spirv.Yield is not printed because the
// parser re-creates it.
2006  printer << " wraps ";
2007  printer.printGenericOp(&getBody().front().front());
2008 }
2009 
// Verifies the body of spirv.SpecConstantOperation:
//   - it must contain exactly the enclosed op plus its terminator;
//   - the enclosed op must pass the validity check below;
//   - every operand must come from a constant-like op.
2010 LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
2011  Block &block = getRegion().getBlocks().front();
2012 
2013  if (block.getOperations().size() != 2)
2014  return emitOpError("expected exactly 2 nested ops");
2015 
2016  Operation &enclosedOp = block.getOperations().front();
2017 
// The condition line is missing from this listing; presumably it checks the
// trait marking ops usable inside SpecConstantOperation — TODO confirm.
2019  return emitOpError("invalid enclosed op");
2020 
// NOTE(review): getDefiningOp() is null for block-argument operands, and
// llvm::isa on a null pointer asserts. Consider isa_and_nonnull if such
// operands can reach this point — TODO confirm.
2021  for (auto operand : enclosedOp.getOperands())
2022  if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
2023  spirv::SpecConstantOperationOp>(operand.getDefiningOp()))
2024  return emitOpError(
2025  "invalid operand, must be defined by a constant operation");
2026 
2027  return success();
2028 }
2029 
2030 //===----------------------------------------------------------------------===//
2031 // spirv.GL.FrexpStruct
2032 //===----------------------------------------------------------------------===//
2033 
2034 LogicalResult spirv::GLFrexpStructOp::verify() {
2035  spirv::StructType structTy =
2036  llvm::dyn_cast<spirv::StructType>(getResult().getType());
2037 
2038  if (structTy.getNumElements() != 2)
2039  return emitError("result type must be a struct type with two memebers");
2040 
2041  Type significandTy = structTy.getElementType(0);
2042  Type exponentTy = structTy.getElementType(1);
2043  VectorType exponentVecTy = llvm::dyn_cast<VectorType>(exponentTy);
2044  IntegerType exponentIntTy = llvm::dyn_cast<IntegerType>(exponentTy);
2045 
2046  Type operandTy = getOperand().getType();
2047  VectorType operandVecTy = llvm::dyn_cast<VectorType>(operandTy);
2048  FloatType operandFTy = llvm::dyn_cast<FloatType>(operandTy);
2049 
2050  if (significandTy != operandTy)
2051  return emitError("member zero of the resulting struct type must be the "
2052  "same type as the operand");
2053 
2054  if (exponentVecTy) {
2055  IntegerType componentIntTy =
2056  llvm::dyn_cast<IntegerType>(exponentVecTy.getElementType());
2057  if (!componentIntTy || componentIntTy.getWidth() != 32)
2058  return emitError("member one of the resulting struct type must"
2059  "be a scalar or vector of 32 bit integer type");
2060  } else if (!exponentIntTy || exponentIntTy.getWidth() != 32) {
2061  return emitError("member one of the resulting struct type "
2062  "must be a scalar or vector of 32 bit integer type");
2063  }
2064 
2065  // Check that the two member types have the same number of components
2066  if (operandVecTy && exponentVecTy &&
2067  (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
2068  return success();
2069 
2070  if (operandFTy && exponentIntTy)
2071  return success();
2072 
2073  return emitError("member one of the resulting struct type must have the same "
2074  "number of components as the operand type");
2075 }
2076 
2077 //===----------------------------------------------------------------------===//
2078 // spirv.GL.Ldexp
2079 //===----------------------------------------------------------------------===//
2080 
2081 LogicalResult spirv::GLLdexpOp::verify() {
2082  Type significandType = getX().getType();
2083  Type exponentType = getExp().getType();
2084 
2085  if (llvm::isa<FloatType>(significandType) !=
2086  llvm::isa<IntegerType>(exponentType))
2087  return emitOpError("operands must both be scalars or vectors");
2088 
2089  auto getNumElements = [](Type type) -> unsigned {
2090  if (auto vectorType = llvm::dyn_cast<VectorType>(type))
2091  return vectorType.getNumElements();
2092  return 1;
2093  };
2094 
2095  if (getNumElements(significandType) != getNumElements(exponentType))
2096  return emitOpError("operands must have the same number of elements");
2097 
2098  return success();
2099 }
2100 
2101 //===----------------------------------------------------------------------===//
2102 // spirv.ShiftLeftLogicalOp
2103 //===----------------------------------------------------------------------===//
2104 
2105 LogicalResult spirv::ShiftLeftLogicalOp::verify() {
2106  return verifyShiftOp(*this);
2107 }
2108 
2109 //===----------------------------------------------------------------------===//
2110 // spirv.ShiftRightArithmeticOp
2111 //===----------------------------------------------------------------------===//
2112 
2113 LogicalResult spirv::ShiftRightArithmeticOp::verify() {
2114  return verifyShiftOp(*this);
2115 }
2116 
2117 //===----------------------------------------------------------------------===//
2118 // spirv.ShiftRightLogicalOp
2119 //===----------------------------------------------------------------------===//
2120 
2121 LogicalResult spirv::ShiftRightLogicalOp::verify() {
2122  return verifyShiftOp(*this);
2123 }
2124 
2125 //===----------------------------------------------------------------------===//
2126 // spirv.VectorTimesScalarOp
2127 //===----------------------------------------------------------------------===//
2128 
2129 LogicalResult spirv::VectorTimesScalarOp::verify() {
2130  if (getVector().getType() != getType())
2131  return emitOpError("vector operand and result type mismatch");
2132  auto scalarType = llvm::cast<VectorType>(getType()).getElementType();
2133  if (getScalar().getType() != scalarType)
2134  return emitOpError("scalar operand and result element type match");
2135  return success();
2136 }
static std::string bindingName()
Returns the string name of the Binding decoration.
static std::string descriptorSetName()
Returns the string name of the DescriptorSet decoration.
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser, OperationState &result)
Definition: SPIRVOps.cpp:267
static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, Type opType)
Definition: SPIRVOps.cpp:563
static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, OperationState &result)
Definition: SPIRVOps.cpp:120
static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op)
Definition: SPIRVOps.cpp:253
static LogicalResult verifyShiftOp(Operation *op)
Definition: SPIRVOps.cpp:299
static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op, Value ptr, Value val)
Definition: SPIRVOps.cpp:169
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:186
static void printOneResultOp(Operation *op, OpAsmPrinter &p)
Definition: SPIRVOps.cpp:149
static void printArithmeticExtendedBinaryOp(Operation *op, OpAsmPrinter &printer)
Definition: SPIRVOps.cpp:291
const float * table
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseOptionalSymbolName(StringAttr &result)=0
Parse an optional @-identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
ParseResult addTypesToList(ArrayRef< Type > types, SmallVectorImpl< Type > &result)
Add the specified types to the end of the specified type list and return success.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:149
unsigned getNumArguments()
Definition: Block.h:128
OpListType & getOperations()
Definition: Block.h:137
Operation & front()
Definition: Block.h:153
iterator begin()
Definition: Block.h:143
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:195
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:223
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:271
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:249
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:75
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:66
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:95
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:257
MLIRContext * getContext() const
Definition: Builders.h:55
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:261
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition: Builders.h:96
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual Operation * parseGenericOperation(Block *insertBlock, Block::iterator insertPt)=0
Parse an operation in its generic form.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printGenericOp(Operation *op, bool printOpName=true)=0
Print the entire operation with the default generic assembly form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:425
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:434
A trait to mark ops that can be enclosed/wrapped in a SpecConstantOperation op.
Definition: SPIRVOpTraits.h:33
type_range getType() const
Definition: ValueRange.cpp:32
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:550
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:40
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:50
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
void push_back(Block *block)
Definition: Region.h:61
bool empty()
Definition: Region.h:60
Block & back()
Definition: Region.h:64
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:107
Type front()
Return first type in the range.
Definition: TypeRange.h:152
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: WalkResult.h:29
static WalkResult advance()
Definition: WalkResult.h:47
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:50
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:447
SPIR-V struct type.
Definition: SPIRVTypes.h:295
unsigned getNumElements() const
Type getElementType(unsigned) const
An attribute that specifies the SPIR-V (version, capabilities, extensions) triple.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:102
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the argument at 'index'.
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
constexpr char kFnNameAttrName[]
constexpr char kSpecIdAttrName[]
static Type getValueType(Attribute attr)
LogicalResult verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics)
Definition: SPIRVOps.cpp:69
ParseResult parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next string attribute in parser as an enumerant of the given EnumClass.
void printVariableDecorations(Operation *op, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs)
Definition: SPIRVOps.cpp:93
AddressingModel getAddressingModel(TargetEnvAttr targetAttr, bool use64bitAddress)
Returns addressing model selected based on target environment.
FailureOr< ExecutionModel > getExecutionModel(TargetEnvAttr targetAttr)
Returns execution model selected based on target environment.
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
LogicalResult extractValueFromConstOp(Operation *op, int32_t &value)
Definition: SPIRVOps.cpp:49
ParseResult parseVariableDecorations(OpAsmParser &parser, OperationState &state)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.