MLIR  22.0.0git
SPIRVOps.cpp
Go to the documentation of this file.
1 //===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the operations in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
15 #include "SPIRVOpUtils.h"
16 #include "SPIRVParsingUtils.h"
17 
24 #include "mlir/IR/Builders.h"
25 #include "mlir/IR/BuiltinTypes.h"
26 #include "mlir/IR/OpDefinition.h"
28 #include "mlir/IR/Operation.h"
29 #include "mlir/IR/TypeUtilities.h"
31 #include "llvm/ADT/APFloat.h"
32 #include "llvm/ADT/APInt.h"
33 #include "llvm/ADT/ArrayRef.h"
34 #include "llvm/ADT/STLExtras.h"
35 #include "llvm/ADT/StringExtras.h"
36 #include "llvm/ADT/TypeSwitch.h"
37 #include "llvm/Support/InterleavedRange.h"
38 #include <cassert>
39 #include <numeric>
40 #include <optional>
41 
42 using namespace mlir;
43 using namespace mlir::spirv::AttrNames;
44 
45 //===----------------------------------------------------------------------===//
46 // Common utility functions
47 //===----------------------------------------------------------------------===//
48 
49 LogicalResult spirv::extractValueFromConstOp(Operation *op, int32_t &value) {
50  auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
51  if (!constOp) {
52  return failure();
53  }
54  auto valueAttr = constOp.getValue();
55  auto integerValueAttr = llvm::dyn_cast<IntegerAttr>(valueAttr);
56  if (!integerValueAttr) {
57  return failure();
58  }
59 
60  if (integerValueAttr.getType().isSignlessInteger())
61  value = integerValueAttr.getInt();
62  else
63  value = integerValueAttr.getSInt();
64 
65  return success();
66 }
67 
/// Verifies that `memorySemantics` sets at most one of the mutually exclusive
/// ordering bits (Acquire, Release, AcquireRelease, SequentiallyConsistent).
/// Emits an error on `op` and returns failure when two or more are set.
// NOTE(review): the line naming this function and declaring its `op`
// parameter is not present in this view. Given the uses of `op` below and the
// call in ControlBarrierOp::verify, it is presumably
// `spirv::verifyMemorySemantics(Operation *op,` — confirm against the full file.
LogicalResult
    spirv::MemorySemantics memorySemantics) {
  // According to the SPIR-V specification:
  // "Despite being a mask and allowing multiple bits to be combined, it is
  // invalid for more than one of these four bits to be set: Acquire, Release,
  // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and
  // Release semantics is done by setting the AcquireRelease bit, not by setting
  // two bits."
  auto atMostOneInSet = spirv::MemorySemantics::Acquire |
                        spirv::MemorySemantics::Release |
                        spirv::MemorySemantics::AcquireRelease |
                        spirv::MemorySemantics::SequentiallyConsistent;

  // Count how many of the exclusive ordering bits are set in the mask.
  auto bitCount =
      llvm::popcount(static_cast<uint32_t>(memorySemantics & atMostOneInSet));
  if (bitCount > 1) {
    // NOTE(review): the concatenated message has no space between "`Release`,"
    // and "`AcquireRelease`".
    return op->emitError(
        "expected at most one of these four memory constraints "
        "to be set: `Acquire`, `Release`,"
        "`AcquireRelease` or `SequentiallyConsistent`");
  }
  return success();
}
92 
/// Prints the decorations of a variable-like op in custom form.
///
/// The following are printed in order:
///   * ` bind(<set>, <binding>)`, when both the DescriptorSet and Binding
///     decoration attributes are present;
///   * ` <builtin-attr-name>("<value>")`, when a BuiltIn attribute is present;
///   * the remaining attributes, excluding the ones printed here and anything
///     already listed in `elidedAttrs`.
///
/// Attribute names are the snake_case forms of the decoration names.
// NOTE(review): the first line of this function's signature (declaring at
// least `op` and `printer`) is not present in this view.
// NOTE(review): `descriptorSetName`, `bindingName` and `builtInName` are local
// std::strings (convertToSnakeFromCamelCase returns std::string). The
// StringRefs pushed into the caller-owned `elidedAttrs` therefore dangle once
// this function returns. They are only safe to use in the printOptionalAttrDict
// call below, so callers must not read `elidedAttrs` afterwards.
    SmallVectorImpl<StringRef> &elidedAttrs) {
  // Print optional descriptor binding
  auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
      stringifyDecoration(spirv::Decoration::DescriptorSet));
  auto bindingName = llvm::convertToSnakeFromCamelCase(
      stringifyDecoration(spirv::Decoration::Binding));
  auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
  auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
  // Only use the compact `bind(...)` form when both halves are present.
  if (descriptorSet && binding) {
    elidedAttrs.push_back(descriptorSetName);
    elidedAttrs.push_back(bindingName);
    printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
            << ")";
  }

  // Print BuiltIn attribute if present
  auto builtInName = llvm::convertToSnakeFromCamelCase(
      stringifyDecoration(spirv::Decoration::BuiltIn));
  if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
    printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
    elidedAttrs.push_back(builtInName);
  }

  printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
}
119 
/// Parses an op with one result whose operands share the result type.
///
/// Two forms are accepted:
///   * the generic fallback form `(%a, %b) {attrs} : <function-type>`, which is
///     printed by printOneResultOp when the types differ; operand and result
///     types are taken from the function type;
///   * the compact form `%a, %b {attrs} : T`, where `T` is used for every
///     operand and for the single result.
// NOTE(review): the declaration of `ops` (the unresolved operand list filled
// below) is not present in this view.
static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
                                                   OperationState &result) {
  Type type;
  // If the operand list is in-between parentheses, then we have a generic form.
  // (see the fallback in `printOneResultOp`).
  SMLoc loc = parser.getCurrentLocation();
  if (!parser.parseOptionalLParen()) {
    if (parser.parseOperandList(ops) || parser.parseRParen() ||
        parser.parseOptionalAttrDict(result.attributes) ||
        parser.parseColon() || parser.parseType(type))
      return failure();
    auto fnType = llvm::dyn_cast<FunctionType>(type);
    if (!fnType) {
      parser.emitError(loc, "expected function type");
      return failure();
    }
    if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands))
      return failure();
    result.addTypes(fnType.getResults());
    return success();
  }
  // Compact form: the single trailing type applies to all operands and result.
  return failure(parser.parseOperandList(ops) ||
                 parser.parseOptionalAttrDict(result.attributes) ||
                 parser.parseColonType(type) ||
                 parser.resolveOperands(ops, type, result.operands) ||
                 parser.addTypeToList(type, result.types));
}
148 
/// Prints an op that has exactly one result.
///
/// When every operand type equals the result type, the compact form is
/// printed: the operands followed by ` : T`. Otherwise the generic form
/// (without the op name) is printed, so no type information is lost. The
/// generic form is the one parseOneResultSameOperandTypeOp detects by its
/// leading parenthesis.
// NOTE(review): this function's signature line, and one statement between
// printing the operands and printing the type, are not present in this view.
  assert(op->getNumResults() == 1 && "op should have one result");

  // If not all the operand and result types are the same, just use the
  // generic assembly form to avoid omitting information in printing.
  auto resultType = op->getResult(0).getType();
  if (llvm::any_of(op->getOperandTypes(),
                   [&](Type type) { return type != resultType; })) {
    p.printGenericOp(op, /*printOpName=*/false);
    return;
  }

  p << ' ';
  p.printOperands(op->getOperands());
  // Now we can output only one type for all operands and the result.
  p << " : " << resultType;
}
167 
168 template <typename BlockReadWriteOpTy>
169 static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
170  Value ptr, Value val) {
171  auto valType = val.getType();
172  if (auto valVecTy = llvm::dyn_cast<VectorType>(valType))
173  valType = valVecTy.getElementType();
174 
175  if (valType !=
176  llvm::cast<spirv::PointerType>(ptr.getType()).getPointeeType()) {
177  return op.emitOpError("mismatch in result type and pointer type");
178  }
179  return success();
180 }
181 
/// Walks the given type hierarchy with the given indices, potentially down
/// to component granularity, to select an element type. Returns null type and
/// reports errors through `emitErrorFn` on failure.
///
/// The index list must be non-empty. Every step must index into a composite
/// type. Each index must be non-negative and, when the composite's element
/// count is known at compile time, smaller than that count.
// NOTE(review): the line naming this overload and declaring its `type` and
// `indices` (ArrayRef<int32_t>) parameters is not present in this view.
// NOTE(review): the "expected at least one index" message names
// spirv.CompositeExtract, but CompositeInsertOp::verify also reaches this
// code via the Attribute overload.
static Type
    function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
  if (indices.empty()) {
    emitErrorFn("expected at least one index for spirv.CompositeExtract");
    return nullptr;
  }

  // Descend one nesting level per index.
  for (auto index : indices) {
    if (auto cType = llvm::dyn_cast<spirv::CompositeType>(type)) {
      // Bounds can only be checked when the element count is static.
      if (cType.hasCompileTimeKnownNumElements() &&
          (index < 0 ||
           static_cast<uint64_t>(index) >= cType.getNumElements())) {
        emitErrorFn("index ") << index << " out of bounds for " << type;
        return nullptr;
      }
      type = cType.getElementType(index);
    } else {
      emitErrorFn("cannot extract from non-composite type ")
          << type << " with index " << index;
      return nullptr;
    }
  }
  return type;
}
210 
/// Attribute-based overload of getElementType.
///
/// Converts `indices`, expected to be an ArrayAttr of IntegerAttrs, into a
/// list of int32_t indices and walks `type` with them. Returns a null type and
/// reports through `emitErrorFn` when `indices` is:
///   * not an array;
///   * empty;
///   * holding a non-integer entry;
///   * not a valid index path into `type`.
// NOTE(review): the line naming this overload and declaring its `type` and
// `indices` (Attribute) parameters is not present in this view.
// NOTE(review): IntegerAttr::getInt() returns int64_t, which is narrowed to
// int32_t by push_back without a range check. getInt() also asserts unless the
// attribute type is signless/index. Confirm that callers only pass signless
// i32 indices.
static Type
    function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
  auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(indices);
  if (!indicesArrayAttr) {
    emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
    return nullptr;
  }
  if (indicesArrayAttr.empty()) {
    emitErrorFn("expected at least one index for spirv.CompositeExtract");
    return nullptr;
  }

  SmallVector<int32_t, 2> indexVals;
  for (auto indexAttr : indicesArrayAttr) {
    auto indexIntAttr = llvm::dyn_cast<IntegerAttr>(indexAttr);
    if (!indexIntAttr) {
      emitErrorFn("expected an 32-bit integer for index, but found '")
          << indexAttr << "'";
      return nullptr;
    }
    indexVals.push_back(indexIntAttr.getInt());
  }
  // Delegate the actual type walk to the ArrayRef<int32_t> overload.
  return getElementType(type, indexVals, emitErrorFn);
}
236 
237 static Type getElementType(Type type, Attribute indices, Location loc) {
238  auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
239  return ::mlir::emitError(loc, err);
240  };
241  return getElementType(type, indices, errorFn);
242 }
243 
244 static Type getElementType(Type type, Attribute indices, OpAsmParser &parser,
245  SMLoc loc) {
246  auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
247  return parser.emitError(loc, err);
248  };
249  return getElementType(type, indices, errorFn);
250 }
251 
252 template <typename ExtendedBinaryOp>
253 static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op) {
254  auto resultType = llvm::cast<spirv::StructType>(op.getType());
255  if (resultType.getNumElements() != 2)
256  return op.emitOpError("expected result struct type containing two members");
257 
258  if (!llvm::all_equal({op.getOperand1().getType(), op.getOperand2().getType(),
259  resultType.getElementType(0),
260  resultType.getElementType(1)}))
261  return op.emitOpError(
262  "expected all operand types and struct member types are the same");
263 
264  return success();
265 }
266 
/// Parses the custom form: an optional attribute dictionary, the operand
/// list, and ` : <result-type>`.
///
/// The result type must be a spirv.struct with exactly two members. Exactly
/// two operands are then resolved, both to the struct's first member type.
// NOTE(review): the declaration of `operands` (the unresolved operand list
// filled below) is not present in this view.
static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser,
                                                   OperationState &result) {
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseOperandList(operands) || parser.parseColon())
    return failure();

  Type resultType;
  SMLoc loc = parser.getCurrentLocation();
  if (parser.parseType(resultType))
    return failure();

  auto structType = llvm::dyn_cast<spirv::StructType>(resultType);
  if (!structType || structType.getNumElements() != 2)
    return parser.emitError(loc, "expected spirv.struct type with two members");

  // Both operands take the struct's (first) member type.
  SmallVector<Type, 2> operandTypes(2, structType.getElementType(0));
  if (parser.resolveOperands(operands, operandTypes, loc, result.operands))
    return failure();

  result.addTypes(resultType);
  return success();
}
290 
/// Prints the custom form read by parseArithmeticExtendedBinaryOp: the
/// attribute dictionary (if any) comes before the operands, followed by ` : `
/// and the first result type.
// NOTE(review): the first line of this function's signature (declaring `op`)
// is not present in this view.
                                            OpAsmPrinter &printer) {
  printer << ' ';
  printer.printOptionalAttrDict(op->getAttrs());
  printer.printOperands(op->getOperands());
  printer << " : " << op->getResultTypes().front();
}
298 
299 static LogicalResult verifyShiftOp(Operation *op) {
300  if (op->getOperand(0).getType() != op->getResult(0).getType()) {
301  return op->emitError("expected the same type for the first operand and "
302  "result, but provided ")
303  << op->getOperand(0).getType() << " and "
304  << op->getResult(0).getType();
305  }
306  return success();
307 }
308 
309 //===----------------------------------------------------------------------===//
310 // spirv.mlir.addressof
311 //===----------------------------------------------------------------------===//
312 
313 void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state,
314  spirv::GlobalVariableOp var) {
315  build(builder, state, var.getType(), SymbolRefAttr::get(var));
316 }
317 
318 LogicalResult spirv::AddressOfOp::verify() {
319  auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
320  SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(),
321  getVariableAttr()));
322  if (!varOp) {
323  return emitOpError("expected spirv.GlobalVariable symbol");
324  }
325  if (getPointer().getType() != varOp.getType()) {
326  return emitOpError(
327  "result type mismatch with the referenced global variable's type");
328  }
329  return success();
330 }
331 
332 //===----------------------------------------------------------------------===//
333 // spirv.CompositeConstruct
334 //===----------------------------------------------------------------------===//
335 
/// Verifies that the constituents of a composite construct match the result
/// type. The accepted shapes are listed in the case breakdown below.
LogicalResult spirv::CompositeConstructOp::verify() {
  operand_range constituents = this->getConstituents();

  // There are 4 cases with varying verification rules:
  // 1. Cooperative Matrices (1 constituent)
  // 2. Structs (1 constituent for each member)
  // 3. Arrays (1 constituent for each array element)
  // 4. Vectors (1 constituent (sub-)element for each vector element)

  // NOTE(review): the first lines of the initializer below are not present in
  // this view. They are presumably a TypeSwitch over the result type that maps
  // cooperative matrix types to their element type. Because of `.Default(nullptr)`,
  // `coopElementType` is null for any other result type.
  auto coopElementType =
          [](auto coopType) { return coopType.getElementType(); })
          .Default(nullptr);

  // Case 1. -- matrices.
  if (coopElementType) {
    if (constituents.size() != 1)
      return emitOpError("has incorrect number of operands: expected ")
             << "1, but provided " << constituents.size();
    if (coopElementType != constituents.front().getType())
      return emitOpError("operand type mismatch: expected operand type ")
             << coopElementType << ", but provided "
             << constituents.front().getType();
    return success();
  }

  // Case 2./3./4. -- number of constituents matches the number of elements.
  auto cType = llvm::cast<spirv::CompositeType>(getType());
  if (constituents.size() == cType.getNumElements()) {
    // One constituent per element: each must have that element's exact type.
    for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
      if (constituents[index].getType() != cType.getElementType(index)) {
        return emitOpError("operand type mismatch: expected operand type ")
               << cType.getElementType(index) << ", but provided "
               << constituents[index].getType();
      }
    }
    return success();
  }

  // Case 4. -- check that all constituents add up to the expected vector type.
  auto resultType = llvm::dyn_cast<VectorType>(cType);
  if (!resultType)
    return emitOpError(
        "expected to return a vector or cooperative matrix when the number of "
        "constituents is less than what the result needs");

  // Number of vector elements contributed by each constituent.
  SmallVector<unsigned> sizes;
  for (Value component : constituents) {
    if (!llvm::isa<VectorType>(component.getType()) &&
        !component.getType().isIntOrFloat())
      return emitOpError("operand type mismatch: expected operand to have "
                         "a scalar or vector type, but provided ")
             << component.getType();

    Type elementType = component.getType();
    if (auto vectorType = llvm::dyn_cast<VectorType>(component.getType())) {
      sizes.push_back(vectorType.getNumElements());
      elementType = vectorType.getElementType();
    } else {
      sizes.push_back(1);
    }

    if (elementType != resultType.getElementType())
      return emitOpError("operand element type mismatch: expected to be ")
             << resultType.getElementType() << ", but provided " << elementType;
  }
  unsigned totalCount = llvm::sum_of(sizes);
  if (totalCount != cType.getNumElements())
    return emitOpError("has incorrect number of operands: expected ")
           << cType.getNumElements() << ", but provided " << totalCount;
  return success();
}
409 
410 //===----------------------------------------------------------------------===//
411 // spirv.CompositeExtractOp
412 //===----------------------------------------------------------------------===//
413 
414 void spirv::CompositeExtractOp::build(OpBuilder &builder, OperationState &state,
415  Value composite,
416  ArrayRef<int32_t> indices) {
417  auto indexAttr = builder.getI32ArrayAttr(indices);
418  auto elementType =
419  getElementType(composite.getType(), indexAttr, state.location);
420  if (!elementType) {
421  return;
422  }
423  build(builder, state, elementType, composite, indexAttr);
424 }
425 
/// Parses the composite operand, the indices array attribute, and
/// ` : <composite-type>`. The result type is inferred by walking the composite
/// type with the indices; invalid indices are reported at the indices' location.
// NOTE(review): the first line of this function's signature is not present in
// this view.
                                          OperationState &result) {
  OpAsmParser::UnresolvedOperand compositeInfo;
  Attribute indicesAttr;
  StringRef indicesAttrName =
      spirv::CompositeExtractOp::getIndicesAttrName(result.name);
  Type compositeType;
  SMLoc attrLocation;

  if (parser.parseOperand(compositeInfo) ||
      parser.getCurrentLocation(&attrLocation) ||
      parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
      parser.parseColonType(compositeType) ||
      parser.resolveOperand(compositeInfo, compositeType, result.operands)) {
    return failure();
  }

  Type resultType =
      getElementType(compositeType, indicesAttr, parser, attrLocation);
  if (!resultType) {
    return failure();
  }
  result.addTypes(resultType);
  return success();
}
451 
/// Prints ` <composite><indices> : <composite-type>`, the form read back by
/// CompositeExtractOp::parse. The result type is not printed because it is
/// inferred.
// NOTE(review): this function's signature line is not present in this view.
  printer << ' ' << getComposite() << getIndices() << " : "
          << getComposite().getType();
}
456 
457 LogicalResult spirv::CompositeExtractOp::verify() {
458  auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(getIndices());
459  auto resultType =
460  getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
461  if (!resultType)
462  return failure();
463 
464  if (resultType != getType()) {
465  return emitOpError("invalid result type: expected ")
466  << resultType << " but provided " << getType();
467  }
468 
469  return success();
470 }
471 
472 //===----------------------------------------------------------------------===//
473 // spirv.CompositeInsert
474 //===----------------------------------------------------------------------===//
475 
476 void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state,
477  Value object, Value composite,
478  ArrayRef<int32_t> indices) {
479  auto indexAttr = builder.getI32ArrayAttr(indices);
480  build(builder, state, composite.getType(), object, composite, indexAttr);
481 }
482 
/// Parses `%object, %composite <indices> : <object-type> into <composite-type>`.
/// The result type is the composite type.
// NOTE(review): the declaration of `operands` (the two unresolved operands
// parsed below) is not present in this view.
ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
                                            OperationState &result) {
  Type objectType, compositeType;
  Attribute indicesAttr;
  StringRef indicesAttrName =
      spirv::CompositeInsertOp::getIndicesAttrName(result.name);
  auto loc = parser.getCurrentLocation();

  // Exactly two operands: the object to insert, then the composite.
  return failure(
      parser.parseOperandList(operands, 2) ||
      parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
      parser.parseColonType(objectType) ||
      parser.parseKeywordType("into", compositeType) ||
      parser.resolveOperands(operands, {objectType, compositeType}, loc,
                             result.operands) ||
      parser.addTypesToList(compositeType, result.types));
}
501 
502 LogicalResult spirv::CompositeInsertOp::verify() {
503  auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(getIndices());
504  auto objectType =
505  getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
506  if (!objectType)
507  return failure();
508 
509  if (objectType != getObject().getType()) {
510  return emitOpError("object operand type should be ")
511  << objectType << ", but found " << getObject().getType();
512  }
513 
514  if (getComposite().getType() != getType()) {
515  return emitOpError("result type should be the same as "
516  "the composite type, but found ")
517  << getComposite().getType() << " vs " << getType();
518  }
519 
520  return success();
521 }
522 
/// Prints ` %object, %composite<indices> : <object-type> into <composite-type>`,
/// the form read back by CompositeInsertOp::parse.
// NOTE(review): this function's signature line is not present in this view.
  printer << " " << getObject() << ", " << getComposite() << getIndices()
          << " : " << getObject().getType() << " into "
          << getComposite().getType();
}
528 
529 //===----------------------------------------------------------------------===//
530 // spirv.Constant
531 //===----------------------------------------------------------------------===//
532 
533 ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
534  OperationState &result) {
535  Attribute value;
536  StringRef valueAttrName = spirv::ConstantOp::getValueAttrName(result.name);
537  if (parser.parseAttribute(value, valueAttrName, result.attributes))
538  return failure();
539 
540  Type type = NoneType::get(parser.getContext());
541  if (auto typedAttr = llvm::dyn_cast<TypedAttr>(value))
542  type = typedAttr.getType();
543  if (llvm::isa<NoneType, TensorType>(type)) {
544  if (parser.parseColonType(type))
545  return failure();
546  }
547 
548  if (llvm::isa<TensorArmType>(type)) {
549  if (parser.parseOptionalColon().succeeded())
550  if (parser.parseType(type))
551  return failure();
552  }
553 
554  return parser.addTypeToList(type, result.types);
555 }
556 
557 void spirv::ConstantOp::print(OpAsmPrinter &printer) {
558  printer << ' ' << getValue();
559  if (llvm::isa<spirv::ArrayType>(getType()))
560  printer << " : " << getType();
561 }
562 
563 static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
564  Type opType) {
565  if (isa<spirv::CooperativeMatrixType>(opType)) {
566  auto denseAttr = dyn_cast<DenseElementsAttr>(value);
567  if (!denseAttr || !denseAttr.isSplat())
568  return op.emitOpError("expected a splat dense attribute for cooperative "
569  "matrix constant, but found ")
570  << denseAttr;
571  }
572  if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
573  auto valueType = llvm::cast<TypedAttr>(value).getType();
574  if (valueType != opType)
575  return op.emitOpError("result type (")
576  << opType << ") does not match value type (" << valueType << ")";
577  return success();
578  }
579  if (llvm::isa<DenseIntOrFPElementsAttr, SparseElementsAttr>(value)) {
580  auto valueType = llvm::cast<TypedAttr>(value).getType();
581  if (valueType == opType)
582  return success();
583  auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
584  auto shapedType = llvm::dyn_cast<ShapedType>(valueType);
585  if (!arrayType)
586  return op.emitOpError("result or element type (")
587  << opType << ") does not match value type (" << valueType
588  << "), must be the same or spirv.array";
589 
590  int numElements = arrayType.getNumElements();
591  auto opElemType = arrayType.getElementType();
592  while (auto t = llvm::dyn_cast<spirv::ArrayType>(opElemType)) {
593  numElements *= t.getNumElements();
594  opElemType = t.getElementType();
595  }
596  if (!opElemType.isIntOrFloat())
597  return op.emitOpError("only support nested array result type");
598 
599  auto valueElemType = shapedType.getElementType();
600  if (valueElemType != opElemType) {
601  return op.emitOpError("result element type (")
602  << opElemType << ") does not match value element type ("
603  << valueElemType << ")";
604  }
605 
606  if (numElements != shapedType.getNumElements()) {
607  return op.emitOpError("result number of elements (")
608  << numElements << ") does not match value number of elements ("
609  << shapedType.getNumElements() << ")";
610  }
611  return success();
612  }
613  if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(value)) {
614  auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
615  if (!arrayType)
616  return op.emitOpError(
617  "must have spirv.array result type for array value");
618  Type elemType = arrayType.getElementType();
619  for (Attribute element : arrayAttr.getValue()) {
620  // Verify array elements recursively.
621  if (failed(verifyConstantType(op, element, elemType)))
622  return failure();
623  }
624  return success();
625  }
626  return op.emitOpError("cannot have attribute: ") << value;
627 }
628 
629 LogicalResult spirv::ConstantOp::verify() {
630  // ODS already generates checks to make sure the result type is valid. We just
631  // need to additionally check that the value's attribute type is consistent
632  // with the result type.
633  return verifyConstantType(*this, getValueAttr(), getType());
634 }
635 
636 bool spirv::ConstantOp::isBuildableWith(Type type) {
637  // Must be valid SPIR-V type first.
638  if (!llvm::isa<spirv::SPIRVType>(type))
639  return false;
640 
641  if (isa<SPIRVDialect>(type.getDialect())) {
642  // TODO: support constant struct
643  return llvm::isa<spirv::ArrayType>(type);
644  }
645 
646  return true;
647 }
648 
649 spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
650  OpBuilder &builder) {
651  if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
652  unsigned width = intType.getWidth();
653  if (width == 1)
654  return spirv::ConstantOp::create(builder, loc, type,
655  builder.getBoolAttr(false));
656  return spirv::ConstantOp::create(
657  builder, loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
658  }
659  if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
660  return spirv::ConstantOp::create(builder, loc, type,
661  builder.getFloatAttr(floatType, 0.0));
662  }
663  if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
664  Type elemType = vectorType.getElementType();
665  if (llvm::isa<IntegerType>(elemType)) {
666  return spirv::ConstantOp::create(
667  builder, loc, type,
668  DenseElementsAttr::get(vectorType,
669  IntegerAttr::get(elemType, 0).getValue()));
670  }
671  if (llvm::isa<FloatType>(elemType)) {
672  return spirv::ConstantOp::create(
673  builder, loc, type,
674  DenseFPElementsAttr::get(vectorType,
675  FloatAttr::get(elemType, 0.0).getValue()));
676  }
677  }
678 
679  llvm_unreachable("unimplemented types for ConstantOp::getZero()");
680 }
681 
682 spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
683  OpBuilder &builder) {
684  if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
685  unsigned width = intType.getWidth();
686  if (width == 1)
687  return spirv::ConstantOp::create(builder, loc, type,
688  builder.getBoolAttr(true));
689  return spirv::ConstantOp::create(
690  builder, loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
691  }
692  if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
693  return spirv::ConstantOp::create(builder, loc, type,
694  builder.getFloatAttr(floatType, 1.0));
695  }
696  if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
697  Type elemType = vectorType.getElementType();
698  if (llvm::isa<IntegerType>(elemType)) {
699  return spirv::ConstantOp::create(
700  builder, loc, type,
701  DenseElementsAttr::get(vectorType,
702  IntegerAttr::get(elemType, 1).getValue()));
703  }
704  if (llvm::isa<FloatType>(elemType)) {
705  return spirv::ConstantOp::create(
706  builder, loc, type,
707  DenseFPElementsAttr::get(vectorType,
708  FloatAttr::get(elemType, 1.0).getValue()));
709  }
710  }
711 
712  llvm_unreachable("unimplemented types for ConstantOp::getOne()");
713 }
714 
715 void mlir::spirv::ConstantOp::getAsmResultNames(
716  llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
717  Type type = getType();
718 
719  SmallString<32> specialNameBuffer;
720  llvm::raw_svector_ostream specialName(specialNameBuffer);
721  specialName << "cst";
722 
723  IntegerType intTy = llvm::dyn_cast<IntegerType>(type);
724 
725  if (IntegerAttr intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
726  assert(intTy);
727 
728  if (intTy.getWidth() == 1) {
729  return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
730  }
731 
732  if (intTy.isSignless()) {
733  specialName << intCst.getInt();
734  } else if (intTy.isUnsigned()) {
735  specialName << intCst.getUInt();
736  } else {
737  specialName << intCst.getSInt();
738  }
739  }
740 
741  if (intTy || llvm::isa<FloatType>(type)) {
742  specialName << '_' << type;
743  }
744 
745  if (auto vecType = llvm::dyn_cast<VectorType>(type)) {
746  specialName << "_vec_";
747  specialName << vecType.getDimSize(0);
748 
749  Type elementType = vecType.getElementType();
750 
751  if (llvm::isa<IntegerType>(elementType) ||
752  llvm::isa<FloatType>(elementType)) {
753  specialName << "x" << elementType;
754  }
755  }
756 
757  setNameFn(getResult(), specialName.str());
758 }
759 
760 void mlir::spirv::AddressOfOp::getAsmResultNames(
761  llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
762  SmallString<32> specialNameBuffer;
763  llvm::raw_svector_ostream specialName(specialNameBuffer);
764  specialName << getVariable() << "_addr";
765  setNameFn(getResult(), specialName.str());
766 }
767 
768 //===----------------------------------------------------------------------===//
769 // spirv.EXTConstantCompositeReplicate
770 //===----------------------------------------------------------------------===//
771 
// Returns type of attribute. In case of a TypedAttr this will simply return
// the type. But for an ArrayAttr which is untyped and can be multidimensional
// it creates the ArrayType recursively. Any other attribute kind yields a null
// type.
// NOTE(review): the signature line (presumably taking `Attribute attr`) is not
// present in this view.
// NOTE(review): the ArrayAttr branch reads `arrayAttr[0]` without checking for
// an empty array, and assumes every element has the same type as the first.
  if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
    return typedAttr.getType();
  }

  if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
    return spirv::ArrayType::get(getValueType(arrayAttr[0]), arrayAttr.size());
  }

  return nullptr;
}
786 
/// Verifies the replicated value against the composite result type.
///
/// The result must be a composite type. The value's type must equal one of the
/// types reached by repeatedly taking element 0: the composite's element type,
/// that type's element type, and so on.
// NOTE(review): this function's signature line is not present in this view.
  Type valueType = getValueType(getValue());
  if (!valueType)
    return emitError("unknown value attribute type");

  auto compositeType = dyn_cast<spirv::CompositeType>(getType());
  if (!compositeType)
    return emitError("result type is not a composite type");

  Type compositeElementType = compositeType.getElementType(0);

  // Collect the element type at every nesting level, outermost first.
  SmallVector<Type, 3> possibleTypes = {compositeElementType};
  while (auto type = dyn_cast<spirv::CompositeType>(compositeElementType)) {
    compositeElementType = type.getElementType(0);
    possibleTypes.push_back(compositeElementType);
  }

  if (!is_contained(possibleTypes, valueType)) {
    return emitError("expected value attribute type ")
           << interleaved(possibleTypes, " or ") << ", but got: " << valueType;
  }

  return success();
}
811 
812 //===----------------------------------------------------------------------===//
813 // spirv.ControlBarrierOp
814 //===----------------------------------------------------------------------===//
815 
816 LogicalResult spirv::ControlBarrierOp::verify() {
817  return verifyMemorySemantics(getOperation(), getMemorySemantics());
818 }
819 
820 //===----------------------------------------------------------------------===//
821 // spirv.EntryPoint
822 //===----------------------------------------------------------------------===//
823 
824 void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state,
825  spirv::ExecutionModel executionModel,
826  spirv::FuncOp function,
827  ArrayRef<Attribute> interfaceVars) {
828  build(builder, state,
829  spirv::ExecutionModelAttr::get(builder.getContext(), executionModel),
830  SymbolRefAttr::get(function), builder.getArrayAttr(interfaceVars));
831 }
832 
/// Parses `"<ExecutionModel>" @fn [, @var0, @var1, ...]`.
///
/// The interface variable symbols are collected into the interface ArrayAttr,
/// which is always added to the op, even when empty.
// NOTE(review): the declaration of `fn` (the parsed function symbol) is not
// present in this view.
ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
                                       OperationState &result) {
  spirv::ExecutionModel execModel;
  SmallVector<Attribute, 4> interfaceVars;

  if (parseEnumStrAttr<spirv::ExecutionModelAttr>(execModel, parser, result) ||
      parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) {
    return failure();
  }

  if (!parser.parseOptionalComma()) {
    // Parse the interface variables
    if (parser.parseCommaSeparatedList([&]() -> ParseResult {
          // The name of the interface variable attribute isn't important: it
          // is only recorded in the throwaway `attrs` list below.
          FlatSymbolRefAttr var;
          NamedAttrList attrs;
          if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
            return failure();
          interfaceVars.push_back(var);
          return success();
        }))
      return failure();
  }
  result.addAttribute(spirv::EntryPointOp::getInterfaceAttrName(result.name),
                      parser.getBuilder().getArrayAttr(interfaceVars));
  return success();
}
861 
/// Prints ` "<ExecutionModel>" @fn`, followed by `, <interface vars>` when the
/// interface list is non-empty.
// NOTE(review): this function's signature line is not present in this view.
  printer << " \"" << stringifyExecutionModel(getExecutionModel()) << "\" ";
  printer.printSymbolName(getFn());
  auto interfaceVars = getInterface().getValue();
  if (!interfaceVars.empty())
    printer << ", " << llvm::interleaved(interfaceVars);
}
869 
870 LogicalResult spirv::EntryPointOp::verify() {
871  // Checks for fn and interface symbol reference are done in spirv::ModuleOp
872  // verification.
873  return success();
874 }
875 
876 //===----------------------------------------------------------------------===//
877 // spirv.ExecutionMode
878 //===----------------------------------------------------------------------===//
879 
880 void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state,
881  spirv::FuncOp function,
882  spirv::ExecutionMode executionMode,
883  ArrayRef<int32_t> params) {
884  build(builder, state, SymbolRefAttr::get(function),
885  spirv::ExecutionModeAttr::get(builder.getContext(), executionMode),
886  builder.getI32ArrayAttr(params));
887 }
888 
889 ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
890  OperationState &result) {
891  spirv::ExecutionMode execMode;
892  Attribute fn;
893  if (parser.parseAttribute(fn, kFnNameAttrName, result.attributes) ||
894  parseEnumStrAttr<spirv::ExecutionModeAttr>(execMode, parser, result)) {
895  return failure();
896  }
897 
899  Type i32Type = parser.getBuilder().getIntegerType(32);
900  while (!parser.parseOptionalComma()) {
901  NamedAttrList attr;
902  Attribute value;
903  if (parser.parseAttribute(value, i32Type, "value", attr)) {
904  return failure();
905  }
906  values.push_back(llvm::cast<IntegerAttr>(value).getInt());
907  }
908  StringRef valuesAttrName =
909  spirv::ExecutionModeOp::getValuesAttrName(result.name);
910  result.addAttribute(valuesAttrName,
911  parser.getBuilder().getI32ArrayAttr(values));
912  return success();
913 }
914 
916  printer << " ";
917  printer.printSymbolName(getFn());
918  printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\"";
919  ArrayAttr values = this->getValues();
920  if (!values.empty())
921  printer << ", " << llvm::interleaved(values.getAsValueRange<IntegerAttr>());
922 }
923 
924 //===----------------------------------------------------------------------===//
925 // spirv.func
926 //===----------------------------------------------------------------------===//
927 
928 ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
930  SmallVector<DictionaryAttr> resultAttrs;
931  SmallVector<Type> resultTypes;
932  auto &builder = parser.getBuilder();
933 
934  // Parse the name as a symbol.
935  StringAttr nameAttr;
936  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
937  result.attributes))
938  return failure();
939 
940  // Parse the function signature.
941  bool isVariadic = false;
943  parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
944  resultAttrs))
945  return failure();
946 
947  SmallVector<Type> argTypes;
948  for (auto &arg : entryArgs)
949  argTypes.push_back(arg.type);
950  auto fnType = builder.getFunctionType(argTypes, resultTypes);
951  result.addAttribute(getFunctionTypeAttrName(result.name),
952  TypeAttr::get(fnType));
953 
954  // Parse the optional function control keyword.
955  spirv::FunctionControl fnControl;
956  if (parseEnumStrAttr<spirv::FunctionControlAttr>(fnControl, parser, result))
957  return failure();
958 
959  // If additional attributes are present, parse them.
960  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
961  return failure();
962 
963  // Add the attributes to the function arguments.
964  assert(resultAttrs.size() == resultTypes.size());
966  builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
967  getResAttrsAttrName(result.name));
968 
969  // Parse the optional function body.
970  auto *body = result.addRegion();
971  OptionalParseResult parseResult =
972  parser.parseOptionalRegion(*body, entryArgs);
973  return failure(parseResult.has_value() && failed(*parseResult));
974 }
975 
976 void spirv::FuncOp::print(OpAsmPrinter &printer) {
977  // Print function name, signature, and control.
978  printer << " ";
979  printer.printSymbolName(getSymName());
980  auto fnType = getFunctionType();
982  printer, *this, fnType.getInputs(),
983  /*isVariadic=*/false, fnType.getResults());
984  printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl())
985  << "\"";
987  printer, *this,
988  {spirv::attributeName<spirv::FunctionControl>(),
989  getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
990  getFunctionControlAttrName()});
991 
992  // Print the body if this is not an external function.
993  Region &body = this->getBody();
994  if (!body.empty()) {
995  printer << ' ';
996  printer.printRegion(body, /*printEntryBlockArgs=*/false,
997  /*printBlockTerminators=*/true);
998  }
999 }
1000 
1001 LogicalResult spirv::FuncOp::verifyType() {
1002  FunctionType fnType = getFunctionType();
1003  if (fnType.getNumResults() > 1)
1004  return emitOpError("cannot have more than one result");
1005 
1006  auto hasDecorationAttr = [&](spirv::Decoration decoration,
1007  unsigned argIndex) {
1008  auto func = llvm::cast<FunctionOpInterface>(getOperation());
1009  for (auto argAttr : cast<FunctionOpInterface>(func).getArgAttrs(argIndex)) {
1010  if (argAttr.getName() != spirv::DecorationAttr::name)
1011  continue;
1012  if (auto decAttr = dyn_cast<spirv::DecorationAttr>(argAttr.getValue()))
1013  return decAttr.getValue() == decoration;
1014  }
1015  return false;
1016  };
1017 
1018  for (unsigned i = 0, e = this->getNumArguments(); i != e; ++i) {
1019  Type param = fnType.getInputs()[i];
1020  auto inputPtrType = dyn_cast<spirv::PointerType>(param);
1021  if (!inputPtrType)
1022  continue;
1023 
1024  auto pointeePtrType =
1025  dyn_cast<spirv::PointerType>(inputPtrType.getPointeeType());
1026  if (pointeePtrType) {
1027  // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
1028  // > If an OpFunctionParameter is a pointer (or contains a pointer)
1029  // > and the type it points to is a pointer in the PhysicalStorageBuffer
1030  // > storage class, the function parameter must be decorated with exactly
1031  // > one of AliasedPointer or RestrictPointer.
1032  if (pointeePtrType.getStorageClass() !=
1033  spirv::StorageClass::PhysicalStorageBuffer)
1034  continue;
1035 
1036  bool hasAliasedPtr =
1037  hasDecorationAttr(spirv::Decoration::AliasedPointer, i);
1038  bool hasRestrictPtr =
1039  hasDecorationAttr(spirv::Decoration::RestrictPointer, i);
1040  if (!hasAliasedPtr && !hasRestrictPtr)
1041  return emitOpError()
1042  << "with a pointer points to a physical buffer pointer must "
1043  "be decorated either 'AliasedPointer' or 'RestrictPointer'";
1044  continue;
1045  }
1046  // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
1047  // > If an OpFunctionParameter is a pointer (or contains a pointer) in
1048  // > the PhysicalStorageBuffer storage class, the function parameter must
1049  // > be decorated with exactly one of Aliased or Restrict.
1050  if (auto pointeeArrayType =
1051  dyn_cast<spirv::ArrayType>(inputPtrType.getPointeeType())) {
1052  pointeePtrType =
1053  dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
1054  } else {
1055  pointeePtrType = inputPtrType;
1056  }
1057 
1058  if (!pointeePtrType || pointeePtrType.getStorageClass() !=
1059  spirv::StorageClass::PhysicalStorageBuffer)
1060  continue;
1061 
1062  bool hasAliased = hasDecorationAttr(spirv::Decoration::Aliased, i);
1063  bool hasRestrict = hasDecorationAttr(spirv::Decoration::Restrict, i);
1064  if (!hasAliased && !hasRestrict)
1065  return emitOpError() << "with physical buffer pointer must be decorated "
1066  "either 'Aliased' or 'Restrict'";
1067  }
1068 
1069  return success();
1070 }
1071 
1072 LogicalResult spirv::FuncOp::verifyBody() {
1073  FunctionType fnType = getFunctionType();
1074  if (!isExternal()) {
1075  Block &entryBlock = front();
1076 
1077  unsigned numArguments = this->getNumArguments();
1078  if (entryBlock.getNumArguments() != numArguments)
1079  return emitOpError("entry block must have ")
1080  << numArguments << " arguments to match function signature";
1081 
1082  for (auto [index, fnArgType, blockArgType] :
1083  llvm::enumerate(getArgumentTypes(), entryBlock.getArgumentTypes())) {
1084  if (blockArgType != fnArgType) {
1085  return emitOpError("type of entry block argument #")
1086  << index << '(' << blockArgType
1087  << ") must match the type of the corresponding argument in "
1088  << "function signature(" << fnArgType << ')';
1089  }
1090  }
1091  }
1092 
1093  auto walkResult = walk([fnType](Operation *op) -> WalkResult {
1094  if (auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
1095  if (fnType.getNumResults() != 0)
1096  return retOp.emitOpError("cannot be used in functions returning value");
1097  } else if (auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
1098  if (fnType.getNumResults() != 1)
1099  return retOp.emitOpError(
1100  "returns 1 value but enclosing function requires ")
1101  << fnType.getNumResults() << " results";
1102 
1103  auto retOperandType = retOp.getValue().getType();
1104  auto fnResultType = fnType.getResult(0);
1105  if (retOperandType != fnResultType)
1106  return retOp.emitOpError(" return value's type (")
1107  << retOperandType << ") mismatch with function's result type ("
1108  << fnResultType << ")";
1109  }
1110  return WalkResult::advance();
1111  });
1112 
1113  // TODO: verify other bits like linkage type.
1114 
1115  return failure(walkResult.wasInterrupted());
1116 }
1117 
1118 void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
1119  StringRef name, FunctionType type,
1120  spirv::FunctionControl control,
1121  ArrayRef<NamedAttribute> attrs) {
1122  state.addAttribute(SymbolTable::getSymbolAttrName(),
1123  builder.getStringAttr(name));
1124  state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
1125  state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
1126  builder.getAttr<spirv::FunctionControlAttr>(control));
1127  state.attributes.append(attrs.begin(), attrs.end());
1128  state.addRegion();
1129 }
1130 
1131 //===----------------------------------------------------------------------===//
1132 // spirv.GLFClampOp
1133 //===----------------------------------------------------------------------===//
1134 
1135 ParseResult spirv::GLFClampOp::parse(OpAsmParser &parser,
1136  OperationState &result) {
1137  return parseOneResultSameOperandTypeOp(parser, result);
1138 }
1140 
1141 //===----------------------------------------------------------------------===//
1142 // spirv.GLUClampOp
1143 //===----------------------------------------------------------------------===//
1144 
1145 ParseResult spirv::GLUClampOp::parse(OpAsmParser &parser,
1146  OperationState &result) {
1147  return parseOneResultSameOperandTypeOp(parser, result);
1148 }
1150 
1151 //===----------------------------------------------------------------------===//
1152 // spirv.GLSClampOp
1153 //===----------------------------------------------------------------------===//
1154 
1155 ParseResult spirv::GLSClampOp::parse(OpAsmParser &parser,
1156  OperationState &result) {
1157  return parseOneResultSameOperandTypeOp(parser, result);
1158 }
1160 
1161 //===----------------------------------------------------------------------===//
1162 // spirv.GLFmaOp
1163 //===----------------------------------------------------------------------===//
1164 
1165 ParseResult spirv::GLFmaOp::parse(OpAsmParser &parser, OperationState &result) {
1166  return parseOneResultSameOperandTypeOp(parser, result);
1167 }
1168 void spirv::GLFmaOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
1169 
1170 //===----------------------------------------------------------------------===//
1171 // spirv.GlobalVariable
1172 //===----------------------------------------------------------------------===//
1173 
1174 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
1175  Type type, StringRef name,
1176  unsigned descriptorSet, unsigned binding) {
1177  build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
1178  state.addAttribute(
1179  spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
1180  builder.getI32IntegerAttr(descriptorSet));
1181  state.addAttribute(
1182  spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
1183  builder.getI32IntegerAttr(binding));
1184 }
1185 
1186 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
1187  Type type, StringRef name,
1188  spirv::BuiltIn builtin) {
1189  build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
1190  state.addAttribute(
1191  spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
1192  builder.getStringAttr(spirv::stringifyBuiltIn(builtin)));
1193 }
1194 
1195 ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
1196  OperationState &result) {
1197  // Parse variable name.
1198  StringAttr nameAttr;
1199  StringRef initializerAttrName =
1200  spirv::GlobalVariableOp::getInitializerAttrName(result.name);
1201  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1202  result.attributes)) {
1203  return failure();
1204  }
1205 
1206  // Parse optional initializer
1207  if (succeeded(parser.parseOptionalKeyword(initializerAttrName))) {
1208  FlatSymbolRefAttr initSymbol;
1209  if (parser.parseLParen() ||
1210  parser.parseAttribute(initSymbol, Type(), initializerAttrName,
1211  result.attributes) ||
1212  parser.parseRParen())
1213  return failure();
1214  }
1215 
1216  if (parseVariableDecorations(parser, result)) {
1217  return failure();
1218  }
1219 
1220  Type type;
1221  StringRef typeAttrName =
1222  spirv::GlobalVariableOp::getTypeAttrName(result.name);
1223  auto loc = parser.getCurrentLocation();
1224  if (parser.parseColonType(type)) {
1225  return failure();
1226  }
1227  if (!llvm::isa<spirv::PointerType>(type)) {
1228  return parser.emitError(loc, "expected spirv.ptr type");
1229  }
1230  result.addAttribute(typeAttrName, TypeAttr::get(type));
1231 
1232  return success();
1233 }
1234 
1236  SmallVector<StringRef, 4> elidedAttrs{
1237  spirv::attributeName<spirv::StorageClass>()};
1238 
1239  // Print variable name.
1240  printer << ' ';
1241  printer.printSymbolName(getSymName());
1242  elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
1243 
1244  StringRef initializerAttrName = this->getInitializerAttrName();
1245  // Print optional initializer
1246  if (auto initializer = this->getInitializer()) {
1247  printer << " " << initializerAttrName << '(';
1248  printer.printSymbolName(*initializer);
1249  printer << ')';
1250  elidedAttrs.push_back(initializerAttrName);
1251  }
1252 
1253  StringRef typeAttrName = this->getTypeAttrName();
1254  elidedAttrs.push_back(typeAttrName);
1255  spirv::printVariableDecorations(*this, printer, elidedAttrs);
1256  printer << " : " << getType();
1257 }
1258 
1259 LogicalResult spirv::GlobalVariableOp::verify() {
1260  if (!llvm::isa<spirv::PointerType>(getType()))
1261  return emitOpError("result must be of a !spv.ptr type");
1262 
1263  // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
1264  // object. It cannot be Generic. It must be the same as the Storage Class
1265  // operand of the Result Type."
1266  // Also, Function storage class is reserved by spirv.Variable.
1267  auto storageClass = this->storageClass();
1268  if (storageClass == spirv::StorageClass::Generic ||
1269  storageClass == spirv::StorageClass::Function) {
1270  return emitOpError("storage class cannot be '")
1271  << stringifyStorageClass(storageClass) << "'";
1272  }
1273 
1274  if (auto init = (*this)->getAttrOfType<FlatSymbolRefAttr>(
1275  this->getInitializerAttrName())) {
1277  (*this)->getParentOp(), init.getAttr());
1278  // TODO: Currently only variable initialization with specialization
1279  // constants is supported. There could be normal constants in the module
1280  // scope as well.
1281  //
1282  // In the current setup we also cannot initialize one global variable with
1283  // another. The problem is that if we try to initialize pointer of type X
1284  // with another pointer type, the validator fails because it expects the
1285  // variable to be initialized to be type X, not pointer to X. Now
1286  // `spirv.GlobalVariable` only allows pointer type, so in the current design
1287  // we cannot initialize one `spirv.GlobalVariable` with another.
1288  if (!initOp ||
1289  !isa<spirv::SpecConstantOp, spirv::SpecConstantCompositeOp>(initOp)) {
1290  return emitOpError("initializer must be result of a "
1291  "spirv.SpecConstant or "
1292  "spirv.SpecConstantCompositeOp op");
1293  }
1294  }
1295 
1296  return success();
1297 }
1298 
1299 //===----------------------------------------------------------------------===//
1300 // spirv.INTEL.SubgroupBlockRead
1301 //===----------------------------------------------------------------------===//
1302 
1304  if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
1305  return failure();
1306 
1307  return success();
1308 }
1309 
1310 //===----------------------------------------------------------------------===//
1311 // spirv.INTEL.SubgroupBlockWrite
1312 //===----------------------------------------------------------------------===//
1313 
1315  OperationState &result) {
1316  // Parse the storage class specification
1317  spirv::StorageClass storageClass;
1319  auto loc = parser.getCurrentLocation();
1320  Type elementType;
1321  if (parseEnumStrAttr(storageClass, parser) ||
1322  parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
1323  parser.parseType(elementType)) {
1324  return failure();
1325  }
1326 
1327  auto ptrType = spirv::PointerType::get(elementType, storageClass);
1328  if (auto valVecTy = llvm::dyn_cast<VectorType>(elementType))
1329  ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
1330 
1331  if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
1332  result.operands)) {
1333  return failure();
1334  }
1335  return success();
1336 }
1337 
1339  printer << " " << getPtr() << ", " << getValue() << " : "
1340  << getValue().getType();
1341 }
1342 
1344  if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
1345  return failure();
1346 
1347  return success();
1348 }
1349 
1350 //===----------------------------------------------------------------------===//
1351 // spirv.IAddCarryOp
1352 //===----------------------------------------------------------------------===//
1353 
1354 LogicalResult spirv::IAddCarryOp::verify() {
1356 }
1357 
1358 ParseResult spirv::IAddCarryOp::parse(OpAsmParser &parser,
1359  OperationState &result) {
1361 }
1362 
1363 void spirv::IAddCarryOp::print(OpAsmPrinter &printer) {
1364  ::printArithmeticExtendedBinaryOp(*this, printer);
1365 }
1366 
1367 //===----------------------------------------------------------------------===//
1368 // spirv.ISubBorrowOp
1369 //===----------------------------------------------------------------------===//
1370 
1371 LogicalResult spirv::ISubBorrowOp::verify() {
1373 }
1374 
1375 ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser,
1376  OperationState &result) {
1378 }
1379 
1381  ::printArithmeticExtendedBinaryOp(*this, printer);
1382 }
1383 
1384 //===----------------------------------------------------------------------===//
1385 // spirv.SMulExtended
1386 //===----------------------------------------------------------------------===//
1387 
1388 LogicalResult spirv::SMulExtendedOp::verify() {
1390 }
1391 
1392 ParseResult spirv::SMulExtendedOp::parse(OpAsmParser &parser,
1393  OperationState &result) {
1395 }
1396 
1398  ::printArithmeticExtendedBinaryOp(*this, printer);
1399 }
1400 
1401 //===----------------------------------------------------------------------===//
1402 // spirv.UMulExtended
1403 //===----------------------------------------------------------------------===//
1404 
1405 LogicalResult spirv::UMulExtendedOp::verify() {
1407 }
1408 
1409 ParseResult spirv::UMulExtendedOp::parse(OpAsmParser &parser,
1410  OperationState &result) {
1412 }
1413 
1415  ::printArithmeticExtendedBinaryOp(*this, printer);
1416 }
1417 
1418 //===----------------------------------------------------------------------===//
1419 // spirv.MemoryBarrierOp
1420 //===----------------------------------------------------------------------===//
1421 
1422 LogicalResult spirv::MemoryBarrierOp::verify() {
1423  return verifyMemorySemantics(getOperation(), getMemorySemantics());
1424 }
1425 
1426 //===----------------------------------------------------------------------===//
1427 // spirv.module
1428 //===----------------------------------------------------------------------===//
1429 
1430 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
1431  std::optional<StringRef> name) {
1432  OpBuilder::InsertionGuard guard(builder);
1433  builder.createBlock(state.addRegion());
1434  if (name) {
1435  state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
1436  builder.getStringAttr(*name));
1437  }
1438 }
1439 
1440 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
1441  spirv::AddressingModel addressingModel,
1442  spirv::MemoryModel memoryModel,
1443  std::optional<VerCapExtAttr> vceTriple,
1444  std::optional<StringRef> name) {
1445  state.addAttribute(
1446  "addressing_model",
1447  builder.getAttr<spirv::AddressingModelAttr>(addressingModel));
1448  state.addAttribute("memory_model",
1449  builder.getAttr<spirv::MemoryModelAttr>(memoryModel));
1450  OpBuilder::InsertionGuard guard(builder);
1451  builder.createBlock(state.addRegion());
1452  if (vceTriple)
1453  state.addAttribute(getVCETripleAttrName(), *vceTriple);
1454  if (name)
1455  state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
1456  builder.getStringAttr(*name));
1457 }
1458 
1459 ParseResult spirv::ModuleOp::parse(OpAsmParser &parser,
1460  OperationState &result) {
1461  Region *body = result.addRegion();
1462 
1463  // If the name is present, parse it.
1464  StringAttr nameAttr;
1465  (void)parser.parseOptionalSymbolName(
1466  nameAttr, mlir::SymbolTable::getSymbolAttrName(), result.attributes);
1467 
1468  // Parse attributes
1469  spirv::AddressingModel addrModel;
1470  spirv::MemoryModel memoryModel;
1471  if (spirv::parseEnumKeywordAttr<spirv::AddressingModelAttr>(addrModel, parser,
1472  result) ||
1473  spirv::parseEnumKeywordAttr<spirv::MemoryModelAttr>(memoryModel, parser,
1474  result))
1475  return failure();
1476 
1477  if (succeeded(parser.parseOptionalKeyword("requires"))) {
1478  spirv::VerCapExtAttr vceTriple;
1479  if (parser.parseAttribute(vceTriple,
1480  spirv::ModuleOp::getVCETripleAttrName(),
1481  result.attributes))
1482  return failure();
1483  }
1484 
1485  if (parser.parseOptionalAttrDictWithKeyword(result.attributes) ||
1486  parser.parseRegion(*body, /*arguments=*/{}))
1487  return failure();
1488 
1489  // Make sure we have at least one block.
1490  if (body->empty())
1491  body->push_back(new Block());
1492 
1493  return success();
1494 }
1495 
1496 void spirv::ModuleOp::print(OpAsmPrinter &printer) {
1497  if (std::optional<StringRef> name = getName()) {
1498  printer << ' ';
1499  printer.printSymbolName(*name);
1500  }
1501 
1502  SmallVector<StringRef, 2> elidedAttrs;
1503 
1504  printer << " " << spirv::stringifyAddressingModel(getAddressingModel()) << " "
1505  << spirv::stringifyMemoryModel(getMemoryModel());
1506  auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
1507  auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
1508  elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
1510 
1511  if (std::optional<spirv::VerCapExtAttr> triple = getVceTriple()) {
1512  printer << " requires " << *triple;
1513  elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
1514  }
1515 
1516  printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs);
1517  printer << ' ';
1518  printer.printRegion(getRegion());
1519 }
1520 
1521 LogicalResult spirv::ModuleOp::verifyRegions() {
1522  Dialect *dialect = (*this)->getDialect();
1524  entryPoints;
1525  mlir::SymbolTable table(*this);
1526 
1527  for (auto &op : *getBody()) {
1528  if (op.getDialect() != dialect)
1529  return op.emitError("'spirv.module' can only contain spirv.* ops");
1530 
1531  // For EntryPoint op, check that the function and execution model is not
1532  // duplicated in EntryPointOps. Also verify that the interface specified
1533  // comes from globalVariables here to make this check cheaper.
1534  if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
1535  auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.getFn());
1536  if (!funcOp) {
1537  return entryPointOp.emitError("function '")
1538  << entryPointOp.getFn() << "' not found in 'spirv.module'";
1539  }
1540  if (auto interface = entryPointOp.getInterface()) {
1541  for (Attribute varRef : interface) {
1542  auto varSymRef = llvm::dyn_cast<FlatSymbolRefAttr>(varRef);
1543  if (!varSymRef) {
1544  return entryPointOp.emitError(
1545  "expected symbol reference for interface "
1546  "specification instead of '")
1547  << varRef;
1548  }
1549  auto variableOp =
1550  table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
1551  if (!variableOp) {
1552  return entryPointOp.emitError("expected spirv.GlobalVariable "
1553  "symbol reference instead of'")
1554  << varSymRef << "'";
1555  }
1556  }
1557  }
1558 
1559  auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
1560  funcOp, entryPointOp.getExecutionModel());
1561  if (!entryPoints.try_emplace(key, entryPointOp).second)
1562  return entryPointOp.emitError("duplicate of a previous EntryPointOp");
1563  } else if (auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
1564  // If the function is external and does not have 'Import'
1565  // linkage_attributes(LinkageAttributes), throw an error. 'Import'
1566  // LinkageAttributes is used to import external functions.
1567  auto linkageAttr = funcOp.getLinkageAttributes();
1568  auto hasImportLinkage =
1569  linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
1570  spirv::LinkageType::Import);
1571  if (funcOp.isExternal() && !hasImportLinkage)
1572  return op.emitError(
1573  "'spirv.module' cannot contain external functions "
1574  "without 'Import' linkage_attributes (LinkageAttributes)");
1575 
1576  // TODO: move this check to spirv.func.
1577  for (auto &block : funcOp)
1578  for (auto &op : block) {
1579  if (op.getDialect() != dialect)
1580  return op.emitError(
1581  "functions in 'spirv.module' can only contain spirv.* ops");
1582  }
1583  }
1584  }
1585 
1586  return success();
1587 }
1588 
1589 //===----------------------------------------------------------------------===//
1590 // spirv.mlir.referenceof
1591 //===----------------------------------------------------------------------===//
1592 
1593 LogicalResult spirv::ReferenceOfOp::verify() {
1594  auto *specConstSym = SymbolTable::lookupNearestSymbolFrom(
1595  (*this)->getParentOp(), getSpecConstAttr());
1596  Type constType;
1597 
1598  auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
1599  if (specConstOp)
1600  constType = specConstOp.getDefaultValue().getType();
1601 
1602  auto specConstCompositeOp =
1603  dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
1604  if (specConstCompositeOp)
1605  constType = specConstCompositeOp.getType();
1606 
1607  if (!specConstOp && !specConstCompositeOp)
1608  return emitOpError(
1609  "expected spirv.SpecConstant or spirv.SpecConstantComposite symbol");
1610 
1611  if (getReference().getType() != constType)
1612  return emitOpError("result type mismatch with the referenced "
1613  "specialization constant's type");
1614 
1615  return success();
1616 }
1617 
1618 //===----------------------------------------------------------------------===//
1619 // spirv.SpecConstant
1620 //===----------------------------------------------------------------------===//
1621 
1622 ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
1623  OperationState &result) {
1624  StringAttr nameAttr;
1625  Attribute valueAttr;
1626  StringRef defaultValueAttrName =
1627  spirv::SpecConstantOp::getDefaultValueAttrName(result.name);
1628 
1629  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1630  result.attributes))
1631  return failure();
1632 
1633  // Parse optional spec_id.
1634  if (succeeded(parser.parseOptionalKeyword(kSpecIdAttrName))) {
1635  IntegerAttr specIdAttr;
1636  if (parser.parseLParen() ||
1637  parser.parseAttribute(specIdAttr, kSpecIdAttrName, result.attributes) ||
1638  parser.parseRParen())
1639  return failure();
1640  }
1641 
1642  if (parser.parseEqual() ||
1643  parser.parseAttribute(valueAttr, defaultValueAttrName, result.attributes))
1644  return failure();
1645 
1646  return success();
1647 }
1648 
1650  printer << ' ';
1651  printer.printSymbolName(getSymName());
1652  if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
1653  printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')';
1654  printer << " = " << getDefaultValue();
1655 }
1656 
1657 LogicalResult spirv::SpecConstantOp::verify() {
1658  if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
1659  if (specID.getValue().isNegative())
1660  return emitOpError("SpecId cannot be negative");
1661 
1662  auto value = getDefaultValue();
1663  if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
1664  // Make sure bitwidth is allowed.
1665  if (!llvm::isa<spirv::SPIRVType>(value.getType()))
1666  return emitOpError("default value bitwidth disallowed");
1667  return success();
1668  }
1669  return emitOpError(
1670  "default value can only be a bool, integer, or float scalar");
1671 }
1672 
1673 //===----------------------------------------------------------------------===//
1674 // spirv.VectorShuffle
1675 //===----------------------------------------------------------------------===//
1676 
1677 LogicalResult spirv::VectorShuffleOp::verify() {
1678  VectorType resultType = llvm::cast<VectorType>(getType());
1679 
1680  size_t numResultElements = resultType.getNumElements();
1681  if (numResultElements != getComponents().size())
1682  return emitOpError("result type element count (")
1683  << numResultElements
1684  << ") mismatch with the number of component selectors ("
1685  << getComponents().size() << ")";
1686 
1687  size_t totalSrcElements =
1688  llvm::cast<VectorType>(getVector1().getType()).getNumElements() +
1689  llvm::cast<VectorType>(getVector2().getType()).getNumElements();
1690 
1691  for (const auto &selector : getComponents().getAsValueRange<IntegerAttr>()) {
1692  uint32_t index = selector.getZExtValue();
1693  if (index >= totalSrcElements &&
1694  index != std::numeric_limits<uint32_t>().max())
1695  return emitOpError("component selector ")
1696  << index << " out of range: expected to be in [0, "
1697  << totalSrcElements << ") or 0xffffffff";
1698  }
1699  return success();
1700 }
1701 
1702 //===----------------------------------------------------------------------===//
1703 // spirv.MatrixTimesScalar
1704 //===----------------------------------------------------------------------===//
1705 
1706 LogicalResult spirv::MatrixTimesScalarOp::verify() {
1707  Type elementType =
1708  llvm::TypeSwitch<Type, Type>(getMatrix().getType())
1710  [](auto matrixType) { return matrixType.getElementType(); })
1711  .Default(nullptr);
1712 
1713  assert(elementType && "Unhandled type");
1714 
1715  // Check that the scalar type is the same as the matrix element type.
1716  if (getScalar().getType() != elementType)
1717  return emitOpError("input matrix components' type and scaling value must "
1718  "have the same type");
1719 
1720  return success();
1721 }
1722 
1723 //===----------------------------------------------------------------------===//
1724 // spirv.Transpose
1725 //===----------------------------------------------------------------------===//
1726 
1727 LogicalResult spirv::TransposeOp::verify() {
1728  auto inputMatrix = llvm::cast<spirv::MatrixType>(getMatrix().getType());
1729  auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().getType());
1730 
1731  // Verify that the input and output matrices have correct shapes.
1732  if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
1733  return emitError("input matrix rows count must be equal to "
1734  "output matrix columns count");
1735 
1736  if (inputMatrix.getNumColumns() != resultMatrix.getNumRows())
1737  return emitError("input matrix columns count must be equal to "
1738  "output matrix rows count");
1739 
1740  // Verify that the input and output matrices have the same component type
1741  if (inputMatrix.getElementType() != resultMatrix.getElementType())
1742  return emitError("input and output matrices must have the same "
1743  "component type");
1744 
1745  return success();
1746 }
1747 
1748 //===----------------------------------------------------------------------===//
1749 // spirv.MatrixTimesVector
1750 //===----------------------------------------------------------------------===//
1751 
1752 LogicalResult spirv::MatrixTimesVectorOp::verify() {
1753  auto matrixType = llvm::cast<spirv::MatrixType>(getMatrix().getType());
1754  auto vectorType = llvm::cast<VectorType>(getVector().getType());
1755  auto resultType = llvm::cast<VectorType>(getType());
1756 
1757  if (matrixType.getNumColumns() != vectorType.getNumElements())
1758  return emitOpError("matrix columns (")
1759  << matrixType.getNumColumns() << ") must match vector operand size ("
1760  << vectorType.getNumElements() << ")";
1761 
1762  if (resultType.getNumElements() != matrixType.getNumRows())
1763  return emitOpError("result size (")
1764  << resultType.getNumElements() << ") must match the matrix rows ("
1765  << matrixType.getNumRows() << ")";
1766 
1767  if (matrixType.getElementType() != resultType.getElementType())
1768  return emitOpError("matrix and result element types must match");
1769 
1770  return success();
1771 }
1772 
1773 //===----------------------------------------------------------------------===//
1774 // spirv.VectorTimesMatrix
1775 //===----------------------------------------------------------------------===//
1776 
1777 LogicalResult spirv::VectorTimesMatrixOp::verify() {
1778  auto vectorType = llvm::cast<VectorType>(getVector().getType());
1779  auto matrixType = llvm::cast<spirv::MatrixType>(getMatrix().getType());
1780  auto resultType = llvm::cast<VectorType>(getType());
1781 
1782  if (matrixType.getNumRows() != vectorType.getNumElements())
1783  return emitOpError("number of components in vector must equal the number "
1784  "of components in each column in matrix");
1785 
1786  if (resultType.getNumElements() != matrixType.getNumColumns())
1787  return emitOpError("number of columns in matrix must equal the number of "
1788  "components in result");
1789 
1790  if (matrixType.getElementType() != resultType.getElementType())
1791  return emitOpError("matrix must be a matrix with the same component type "
1792  "as the component type in result");
1793 
1794  return success();
1795 }
1796 
1797 //===----------------------------------------------------------------------===//
1798 // spirv.MatrixTimesMatrix
1799 //===----------------------------------------------------------------------===//
1800 
1801 LogicalResult spirv::MatrixTimesMatrixOp::verify() {
1802  auto leftMatrix = llvm::cast<spirv::MatrixType>(getLeftmatrix().getType());
1803  auto rightMatrix = llvm::cast<spirv::MatrixType>(getRightmatrix().getType());
1804  auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().getType());
1805 
1806  // left matrix columns' count and right matrix rows' count must be equal
1807  if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
1808  return emitError("left matrix columns' count must be equal to "
1809  "the right matrix rows' count");
1810 
1811  // right and result matrices columns' count must be the same
1812  if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns())
1813  return emitError(
1814  "right and result matrices must have equal columns' count");
1815 
1816  // right and result matrices component type must be the same
1817  if (rightMatrix.getElementType() != resultMatrix.getElementType())
1818  return emitError("right and result matrices' component type must"
1819  " be the same");
1820 
1821  // left and result matrices component type must be the same
1822  if (leftMatrix.getElementType() != resultMatrix.getElementType())
1823  return emitError("left and result matrices' component type"
1824  " must be the same");
1825 
1826  // left and result matrices rows count must be the same
1827  if (leftMatrix.getNumRows() != resultMatrix.getNumRows())
1828  return emitError("left and result matrices must have equal rows' count");
1829 
1830  return success();
1831 }
1832 
1833 //===----------------------------------------------------------------------===//
1834 // spirv.SpecConstantComposite
1835 //===----------------------------------------------------------------------===//
1836 
1838  OperationState &result) {
  // Custom assembly form: `@sym_name ( @c0, @c1, ... ) : type`.
1839 
  // Symbol name of this composite, stored under the symbol-name attribute.
1840  StringAttr compositeName;
1841  if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
1842  result.attributes))
1843  return failure();
1844 
1845  if (parser.parseLParen())
1846  return failure();
1847 
1848  SmallVector<Attribute, 4> constituents;
1849 
  // One or more comma-separated constituent symbol references. The body runs
  // before the first comma check, so an empty `()` list fails to parse.
1850  do {
1851  // The name of the constituent attribute isn't important
  // `attrs` is not read after this call; only `specConstRef` is kept.
1852  const char *attrName = "spec_const";
1853  FlatSymbolRefAttr specConstRef;
1854  NamedAttrList attrs;
1855 
1856  if (parser.parseAttribute(specConstRef, Type(), attrName, attrs))
1857  return failure();
1858 
1859  constituents.push_back(specConstRef);
1860  } while (!parser.parseOptionalComma());
1861 
1862  if (parser.parseRParen())
1863  return failure();
1864 
  // Store the references as an ArrayAttr under the op's constituents attribute.
1865  StringAttr compositeSpecConstituentsName =
1866  spirv::SpecConstantCompositeOp::getConstituentsAttrName(result.name);
1867  result.addAttribute(compositeSpecConstituentsName,
1868  parser.getBuilder().getArrayAttr(constituents));
1869 
  // The trailing `: type` is recorded as the op's type attribute; this parser
  // adds no result types.
1870  Type type;
1871  if (parser.parseColonType(type))
1872  return failure();
1873 
1874  StringAttr typeAttrName =
1875  spirv::SpecConstantCompositeOp::getTypeAttrName(result.name);
1876  result.addAttribute(typeAttrName, TypeAttr::get(type));
1877 
1878  return success();
1879 }
1880 
  // Prints ` @sym_name (<constituents>) : <type>`: the symbol name, the
  // constituent attributes joined by llvm::interleaved, and the type attribute.
1882  printer << " ";
1883  printer.printSymbolName(getSymName());
1884  printer << " (" << llvm::interleaved(this->getConstituents().getValue())
1885  << ") : " << getType();
1886 }
1887 
1889  auto cType = llvm::dyn_cast<spirv::CompositeType>(getType());
1890  auto constituents = this->getConstituents().getValue();
1891 
1892  if (!cType)
1893  return emitError("result type must be a composite type, but provided ")
1894  << getType();
1895 
1896  if (llvm::isa<spirv::CooperativeMatrixType>(cType))
1897  return emitError("unsupported composite type ") << cType;
1898  if (constituents.size() != cType.getNumElements())
1899  return emitError("has incorrect number of operands: expected ")
1900  << cType.getNumElements() << ", but provided "
1901  << constituents.size();
1902 
1903  for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
1904  auto constituent = llvm::cast<FlatSymbolRefAttr>(constituents[index]);
1905 
1906  auto constituentSpecConstOp =
1907  dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom(
1908  (*this)->getParentOp(), constituent.getAttr()));
1909 
1910  if (constituentSpecConstOp.getDefaultValue().getType() !=
1911  cType.getElementType(index))
1912  return emitError("has incorrect types of operands: expected ")
1913  << cType.getElementType(index) << ", but provided "
1914  << constituentSpecConstOp.getDefaultValue().getType();
1915  }
1916 
1917  return success();
1918 }
1919 
1920 //===----------------------------------------------------------------------===//
1921 // spirv.EXTSpecConstantCompositeReplicateOp
1922 //===----------------------------------------------------------------------===//
1923 
// Parser for spirv.EXTSpecConstantCompositeReplicate.
1924 ParseResult
1926  OperationState &result) {
  // Custom assembly form: `@sym_name ( @constituent ) : type`, with exactly one
  // constituent reference.
1927  StringAttr compositeName;
1928  FlatSymbolRefAttr specConstRef;
  // `attrName`/`attrs` are arguments required by parseAttribute; neither is
  // read afterwards, only `specConstRef` is kept.
1929  const char *attrName = "spec_const";
1930  NamedAttrList attrs;
1931  Type type;
1932 
  // Stops at the first failing sub-parse (short-circuiting `||`).
1933  if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
1934  result.attributes) ||
1935  parser.parseLParen() ||
1936  parser.parseAttribute(specConstRef, Type(), attrName, attrs) ||
1937  parser.parseRParen() || parser.parseColonType(type))
1938  return failure();
1939 
  // Store the single constituent reference and the composite type as
  // attributes; no result types are added here.
1940  StringAttr compositeSpecConstituentName =
1941  spirv::EXTSpecConstantCompositeReplicateOp::getConstituentAttrName(
1942  result.name);
1943  result.addAttribute(compositeSpecConstituentName, specConstRef);
1944 
1945  StringAttr typeAttrName =
1946  spirv::EXTSpecConstantCompositeReplicateOp::getTypeAttrName(result.name);
1947  result.addAttribute(typeAttrName, TypeAttr::get(type));
1948 
1949  return success();
1950 }
1951 
  // Prints ` @sym_name (<constituent>) : <type>`.
1953  printer << " ";
1954  printer.printSymbolName(getSymName());
1955  printer << " (" << this->getConstituent() << ") : " << getType();
1956 }
1957 
  // Verifies that the result type is a composite type, that the constituent
  // symbol resolves to a spirv.SpecConstant, and that the constituent's
  // default value type equals the composite's element-0 type.
1959  auto compositeType = dyn_cast<spirv::CompositeType>(getType());
1960  if (!compositeType)
1961  return emitError("result type must be a composite type, but provided ")
1962  << getType();
1963 
  // Resolve the constituent symbol from the enclosing op's scope.
  // NOTE(review): the line declaring `constituentOp` (the start of this lookup
  // call) is absent from this listing.
1965  (*this)->getParentOp(), this->getConstituent());
1966  if (!constituentOp)
1967  return emitError(
1968  "splat spec constant reference defining constituent not found");
1969 
1970  auto constituentSpecConstOp = dyn_cast<spirv::SpecConstantOp>(constituentOp);
1971  if (!constituentSpecConstOp)
1972  return emitError("constituent is not a spec constant");
1973 
  // Only element 0 is compared, presumably because every element of a
  // replicated composite is this same constituent. TODO confirm for struct
  // result types whose members differ in type.
1974  Type constituentType = constituentSpecConstOp.getDefaultValue().getType();
1975  Type compositeElementType = compositeType.getElementType(0);
1976  if (constituentType != compositeElementType)
1977  return emitError("constituent has incorrect type: expected ")
1978  << compositeElementType << ", but provided " << constituentType;
1979 
1980  return success();
1981 }
1982 
1983 //===----------------------------------------------------------------------===//
1984 // spirv.SpecConstantOperation
1985 //===----------------------------------------------------------------------===//
1986 
1988  OperationState &result) {
  // Custom form: `wraps <generic op>` followed by an optional attribute dict.
  // The wrapped op is parsed into a fresh block of the op's single region, and
  // that block is then terminated with a spirv.Yield of the wrapped op's first
  // result.
1989  Region *body = result.addRegion();
1990 
1991  if (parser.parseKeyword("wraps"))
1992  return failure();
1993 
1994  body->push_back(new Block);
1995  Block &block = body->back();
1996  Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
1997 
1998  if (!wrappedOp)
1999  return failure();
2000 
  // NOTE(review): getResult(0) below assumes the wrapped op has at least one
  // result. A zero-result wrapped op would be indexed out of range before any
  // verifier runs; consider emitting a parse error when getNumResults() == 0.
2001  OpBuilder builder(parser.getContext());
2002  builder.setInsertionPointToEnd(&block);
2003  spirv::YieldOp::create(builder, wrappedOp->getLoc(), wrappedOp->getResult(0));
  // This op adopts the wrapped op's location and its first result's type.
2004  result.location = wrappedOp->getLoc();
2005 
2006  result.addTypes(wrappedOp->getResult(0).getType());
2007 
2008  if (parser.parseOptionalAttrDict(result.attributes))
2009  return failure();
2010 
2011  return success();
2012 }
2013 
  // Prints ` wraps ` followed by the generic form of the first op in the
  // body's first block (the wrapped op). The trailing terminator is not
  // printed.
  // NOTE(review): the parser accepts an optional attribute dictionary, but
  // nothing here prints this op's attributes, so any such attributes would not
  // round-trip — confirm whether that can happen.
2015  printer << " wraps ";
2016  printer.printGenericOp(&getBody().front().front());
2017 }
2018 
// Verifies the body of spirv.SpecConstantOperation:
//   - the first block must hold exactly two ops (the parser emits the enclosed
//     op followed by a spirv.Yield);
//   - the enclosed op must pass the validity check below;
//   - every operand of the enclosed op must be defined by a spirv::ConstantOp,
//     a spirv::ReferenceOfOp, or another spirv::SpecConstantOperationOp.
2019 LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
2020  Block &block = getRegion().getBlocks().front();
2021 
2022  if (block.getOperations().size() != 2)
2023  return emitOpError("expected exactly 2 nested ops");
2024 
2025  Operation &enclosedOp = block.getOperations().front();
2026 
  // NOTE(review): the condition guarding the return below is absent from this
  // listing; it presumably tests whether the enclosed op may be wrapped.
2028  return emitOpError("invalid enclosed op");
2029 
  // NOTE(review): getDefiningOp() is null for block arguments, and llvm::isa<>
  // asserts on a null pointer. If an operand can be a block argument here,
  // llvm::isa_and_nonnull would diagnose it instead of crashing.
2030  for (auto operand : enclosedOp.getOperands())
2031  if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
2032  spirv::SpecConstantOperationOp>(operand.getDefiningOp()))
2033  return emitOpError(
2034  "invalid operand, must be defined by a constant operation");
2035 
2036  return success();
2037 }
2038 
2039 //===----------------------------------------------------------------------===//
2040 // spirv.GL.FrexpStruct
2041 //===----------------------------------------------------------------------===//
2042 
2043 LogicalResult spirv::GLFrexpStructOp::verify() {
2044  spirv::StructType structTy =
2045  llvm::dyn_cast<spirv::StructType>(getResult().getType());
2046 
2047  if (structTy.getNumElements() != 2)
2048  return emitError("result type must be a struct type with two memebers");
2049 
2050  Type significandTy = structTy.getElementType(0);
2051  Type exponentTy = structTy.getElementType(1);
2052  VectorType exponentVecTy = llvm::dyn_cast<VectorType>(exponentTy);
2053  IntegerType exponentIntTy = llvm::dyn_cast<IntegerType>(exponentTy);
2054 
2055  Type operandTy = getOperand().getType();
2056  VectorType operandVecTy = llvm::dyn_cast<VectorType>(operandTy);
2057  FloatType operandFTy = llvm::dyn_cast<FloatType>(operandTy);
2058 
2059  if (significandTy != operandTy)
2060  return emitError("member zero of the resulting struct type must be the "
2061  "same type as the operand");
2062 
2063  if (exponentVecTy) {
2064  IntegerType componentIntTy =
2065  llvm::dyn_cast<IntegerType>(exponentVecTy.getElementType());
2066  if (!componentIntTy || componentIntTy.getWidth() != 32)
2067  return emitError("member one of the resulting struct type must"
2068  "be a scalar or vector of 32 bit integer type");
2069  } else if (!exponentIntTy || exponentIntTy.getWidth() != 32) {
2070  return emitError("member one of the resulting struct type "
2071  "must be a scalar or vector of 32 bit integer type");
2072  }
2073 
2074  // Check that the two member types have the same number of components
2075  if (operandVecTy && exponentVecTy &&
2076  (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
2077  return success();
2078 
2079  if (operandFTy && exponentIntTy)
2080  return success();
2081 
2082  return emitError("member one of the resulting struct type must have the same "
2083  "number of components as the operand type");
2084 }
2085 
2086 //===----------------------------------------------------------------------===//
2087 // spirv.GL.Ldexp
2088 //===----------------------------------------------------------------------===//
2089 
2090 LogicalResult spirv::GLLdexpOp::verify() {
2091  Type significandType = getX().getType();
2092  Type exponentType = getExp().getType();
2093 
2094  if (llvm::isa<FloatType>(significandType) !=
2095  llvm::isa<IntegerType>(exponentType))
2096  return emitOpError("operands must both be scalars or vectors");
2097 
2098  auto getNumElements = [](Type type) -> unsigned {
2099  if (auto vectorType = llvm::dyn_cast<VectorType>(type))
2100  return vectorType.getNumElements();
2101  return 1;
2102  };
2103 
2104  if (getNumElements(significandType) != getNumElements(exponentType))
2105  return emitOpError("operands must have the same number of elements");
2106 
2107  return success();
2108 }
2109 
2110 //===----------------------------------------------------------------------===//
2111 // spirv.ShiftLeftLogicalOp
2112 //===----------------------------------------------------------------------===//
2113 
2114 LogicalResult spirv::ShiftLeftLogicalOp::verify() {
2115  return verifyShiftOp(*this);
2116 }
2117 
2118 //===----------------------------------------------------------------------===//
2119 // spirv.ShiftRightArithmeticOp
2120 //===----------------------------------------------------------------------===//
2121 
2122 LogicalResult spirv::ShiftRightArithmeticOp::verify() {
2123  return verifyShiftOp(*this);
2124 }
2125 
2126 //===----------------------------------------------------------------------===//
2127 // spirv.ShiftRightLogicalOp
2128 //===----------------------------------------------------------------------===//
2129 
2130 LogicalResult spirv::ShiftRightLogicalOp::verify() {
2131  return verifyShiftOp(*this);
2132 }
2133 
2134 //===----------------------------------------------------------------------===//
2135 // spirv.VectorTimesScalarOp
2136 //===----------------------------------------------------------------------===//
2137 
2138 LogicalResult spirv::VectorTimesScalarOp::verify() {
2139  if (getVector().getType() != getType())
2140  return emitOpError("vector operand and result type mismatch");
2141  auto scalarType = llvm::cast<VectorType>(getType()).getElementType();
2142  if (getScalar().getType() != scalarType)
2143  return emitOpError("scalar operand and result element type match");
2144  return success();
2145 }
static std::string bindingName()
Returns the string name of the Binding decoration.
static std::string descriptorSetName()
Returns the string name of the DescriptorSet decoration.
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser, OperationState &result)
Definition: SPIRVOps.cpp:267
static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, Type opType)
Definition: SPIRVOps.cpp:563
static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, OperationState &result)
Definition: SPIRVOps.cpp:120
static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op)
Definition: SPIRVOps.cpp:253
static LogicalResult verifyShiftOp(Operation *op)
Definition: SPIRVOps.cpp:299
static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op, Value ptr, Value val)
Definition: SPIRVOps.cpp:169
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:186
static void printOneResultOp(Operation *op, OpAsmPrinter &p)
Definition: SPIRVOps.cpp:149
static void printArithmeticExtendedBinaryOp(Operation *op, OpAsmPrinter &printer)
Definition: SPIRVOps.cpp:291
ArrayRef< float > table
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseOptionalSymbolName(StringAttr &result)=0
Parse an optional @-identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
ParseResult addTypesToList(ArrayRef< Type > types, SmallVectorImpl< Type > &result)
Add the specified types to the end of the specified type list and return success.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:149
unsigned getNumArguments()
Definition: Block.h:128
OpListType & getOperations()
Definition: Block.h:137
Operation & front()
Definition: Block.h:153
iterator begin()
Definition: Block.h:143
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:200
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:228
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:276
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:254
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:76
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:67
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:100
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:262
MLIRContext * getContext() const
Definition: Builders.h:56
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:266
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition: Builders.h:98
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:316
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual Operation * parseGenericOperation(Block *insertBlock, Block::iterator insertPt)=0
Parse an operation in its generic form.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printGenericOp(Operation *op, bool printOpName=true)=0
Print the entire operation with the default generic assembly form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:348
This class helps build Operations.
Definition: Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:430
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:436
A trait to mark ops that can be enclosed/wrapped in a SpecConstantOperation op.
Definition: SPIRVOpTraits.h:33
type_range getType() const
Definition: ValueRange.cpp:32
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:550
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:40
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:50
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
void push_back(Block *block)
Definition: Region.h:61
bool empty()
Definition: Region.h:60
Block & back()
Definition: Region.h:64
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:107
Type front()
Return first type in the range.
Definition: TypeRange.h:152
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: WalkResult.h:29
static WalkResult advance()
Definition: WalkResult.h:47
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:158
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:451
SPIR-V struct type.
Definition: SPIRVTypes.h:251
unsigned getNumElements() const
Type getElementType(unsigned) const
An attribute that specifies the SPIR-V (version, capabilities, extensions) triple.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:102
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the argument at 'index'.
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:561
constexpr char kFnNameAttrName[]
constexpr char kSpecIdAttrName[]
static Type getValueType(Attribute attr)
LogicalResult verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics)
Definition: SPIRVOps.cpp:69
ParseResult parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next string attribute in parser as an enumerant of the given EnumClass.
void printVariableDecorations(Operation *op, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs)
Definition: SPIRVOps.cpp:93
AddressingModel getAddressingModel(TargetEnvAttr targetAttr, bool use64bitAddress)
Returns addressing model selected based on target environment.
FailureOr< ExecutionModel > getExecutionModel(TargetEnvAttr targetAttr)
Returns execution model selected based on target environment.
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
LogicalResult extractValueFromConstOp(Operation *op, int32_t &value)
Definition: SPIRVOps.cpp:49
ParseResult parseVariableDecorations(OpAsmParser &parser, OperationState &state)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.