31 #include "llvm/ADT/APFloat.h"
32 #include "llvm/ADT/APInt.h"
33 #include "llvm/ADT/ArrayRef.h"
34 #include "llvm/ADT/STLExtras.h"
35 #include "llvm/ADT/StringExtras.h"
36 #include "llvm/ADT/TypeSwitch.h"
37 #include "llvm/Support/InterleavedRange.h"
50 auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
54 auto valueAttr = constOp.getValue();
55 auto integerValueAttr = llvm::dyn_cast<IntegerAttr>(valueAttr);
56 if (!integerValueAttr) {
60 if (integerValueAttr.getType().isSignlessInteger())
61 value = integerValueAttr.getInt();
63 value = integerValueAttr.getSInt();
70 spirv::MemorySemantics memorySemantics) {
77 auto atMostOneInSet = spirv::MemorySemantics::Acquire |
78 spirv::MemorySemantics::Release |
79 spirv::MemorySemantics::AcquireRelease |
80 spirv::MemorySemantics::SequentiallyConsistent;
83 llvm::popcount(
static_cast<uint32_t
>(memorySemantics & atMostOneInSet));
86 "expected at most one of these four memory constraints "
87 "to be set: `Acquire`, `Release`,"
88 "`AcquireRelease` or `SequentiallyConsistent`");
97 stringifyDecoration(spirv::Decoration::DescriptorSet));
98 auto bindingName = llvm::convertToSnakeFromCamelCase(
99 stringifyDecoration(spirv::Decoration::Binding));
102 if (descriptorSet && binding) {
105 printer <<
" bind(" << descriptorSet.getInt() <<
", " << binding.getInt()
110 auto builtInName = llvm::convertToSnakeFromCamelCase(
111 stringifyDecoration(spirv::Decoration::BuiltIn));
112 if (
auto builtin = op->
getAttrOfType<StringAttr>(builtInName)) {
113 printer <<
" " << builtInName <<
"(\"" << builtin.getValue() <<
"\")";
114 elidedAttrs.push_back(builtInName);
132 auto fnType = llvm::dyn_cast<FunctionType>(type);
134 parser.
emitError(loc,
"expected function type");
139 result.
addTypes(fnType.getResults());
150 assert(op->
getNumResults() == 1 &&
"op should have one result");
156 [&](
Type type) { return type != resultType; })) {
165 p <<
" : " << resultType;
168 template <
typename BlockReadWriteOpTy>
172 if (
auto valVecTy = llvm::dyn_cast<VectorType>(valType))
173 valType = valVecTy.getElementType();
176 llvm::cast<spirv::PointerType>(ptr.
getType()).getPointeeType()) {
177 return op.emitOpError(
"mismatch in result type and pointer type");
188 if (indices.empty()) {
189 emitErrorFn(
"expected at least one index for spirv.CompositeExtract");
193 for (
auto index : indices) {
194 if (
auto cType = llvm::dyn_cast<spirv::CompositeType>(type)) {
195 if (cType.hasCompileTimeKnownNumElements() &&
197 static_cast<uint64_t
>(index) >= cType.getNumElements())) {
198 emitErrorFn(
"index ") << index <<
" out of bounds for " << type;
201 type = cType.getElementType(index);
203 emitErrorFn(
"cannot extract from non-composite type ")
204 << type <<
" with index " << index;
214 auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(indices);
215 if (!indicesArrayAttr) {
216 emitErrorFn(
"expected a 32-bit integer array attribute for 'indices'");
219 if (indicesArrayAttr.empty()) {
220 emitErrorFn(
"expected at least one index for spirv.CompositeExtract");
225 for (
auto indexAttr : indicesArrayAttr) {
226 auto indexIntAttr = llvm::dyn_cast<IntegerAttr>(indexAttr);
228 emitErrorFn(
"expected an 32-bit integer for index, but found '")
232 indexVals.push_back(indexIntAttr.getInt());
252 template <
typename ExtendedBinaryOp>
254 auto resultType = llvm::cast<spirv::StructType>(op.getType());
255 if (resultType.getNumElements() != 2)
256 return op.emitOpError(
"expected result struct type containing two members");
258 if (!llvm::all_equal({op.getOperand1().
getType(), op.getOperand2().getType(),
259 resultType.getElementType(0),
260 resultType.getElementType(1)}))
261 return op.emitOpError(
262 "expected all operand types and struct member types are the same");
279 auto structType = llvm::dyn_cast<spirv::StructType>(resultType);
280 if (!structType || structType.getNumElements() != 2)
281 return parser.
emitError(loc,
"expected spirv.struct type with two members");
301 return op->
emitError(
"expected the same type for the first operand and "
302 "result, but provided ")
314 spirv::GlobalVariableOp var) {
319 auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
323 return emitOpError(
"expected spirv.GlobalVariable symbol");
325 if (getPointer().
getType() != varOp.getType()) {
327 "result type mismatch with the referenced global variable's type");
337 operand_range constituents = this->getConstituents();
345 auto coopElementType =
348 [](
auto coopType) {
return coopType.getElementType(); })
349 .Default([](
Type) {
return nullptr; });
352 if (coopElementType) {
353 if (constituents.size() != 1)
354 return emitOpError(
"has incorrect number of operands: expected ")
355 <<
"1, but provided " << constituents.size();
356 if (coopElementType != constituents.front().getType())
357 return emitOpError(
"operand type mismatch: expected operand type ")
358 << coopElementType <<
", but provided "
359 << constituents.front().getType();
364 auto cType = llvm::cast<spirv::CompositeType>(
getType());
365 if (constituents.size() == cType.getNumElements()) {
366 for (
auto index : llvm::seq<uint32_t>(0, constituents.size())) {
367 if (constituents[index].
getType() != cType.getElementType(index)) {
368 return emitOpError(
"operand type mismatch: expected operand type ")
369 << cType.getElementType(index) <<
", but provided "
370 << constituents[index].getType();
377 auto resultType = llvm::dyn_cast<VectorType>(cType);
380 "expected to return a vector or cooperative matrix when the number of "
381 "constituents is less than what the result needs");
384 for (
Value component : constituents) {
385 if (!llvm::isa<VectorType>(component.getType()) &&
386 !component.getType().isIntOrFloat())
387 return emitOpError(
"operand type mismatch: expected operand to have "
388 "a scalar or vector type, but provided ")
389 << component.getType();
391 Type elementType = component.getType();
392 if (
auto vectorType = llvm::dyn_cast<VectorType>(component.getType())) {
393 sizes.push_back(vectorType.getNumElements());
394 elementType = vectorType.getElementType();
399 if (elementType != resultType.getElementType())
400 return emitOpError(
"operand element type mismatch: expected to be ")
401 << resultType.getElementType() <<
", but provided " << elementType;
403 unsigned totalCount = std::accumulate(sizes.begin(), sizes.end(), 0);
404 if (totalCount != cType.getNumElements())
405 return emitOpError(
"has incorrect number of operands: expected ")
406 << cType.getNumElements() <<
", but provided " << totalCount;
423 build(builder, state, elementType, composite, indexAttr);
430 StringRef indicesAttrName =
431 spirv::CompositeExtractOp::getIndicesAttrName(result.
name);
453 printer <<
' ' << getComposite() <<
getIndices() <<
" : "
458 auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(
getIndices());
465 return emitOpError(
"invalid result type: expected ")
466 << resultType <<
" but provided " <<
getType();
480 build(builder, state, composite.
getType(),
object, composite, indexAttr);
486 Type objectType, compositeType;
488 StringRef indicesAttrName =
489 spirv::CompositeInsertOp::getIndicesAttrName(result.
name);
503 auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(
getIndices());
509 if (objectType != getObject().
getType()) {
510 return emitOpError(
"object operand type should be ")
511 << objectType <<
", but found " << getObject().getType();
515 return emitOpError(
"result type should be the same as "
516 "the composite type, but found ")
517 << getComposite().getType() <<
" vs " <<
getType();
524 printer <<
" " << getObject() <<
", " << getComposite() <<
getIndices()
525 <<
" : " << getObject().
getType() <<
" into "
526 << getComposite().getType();
536 StringRef valueAttrName = spirv::ConstantOp::getValueAttrName(result.
name);
541 if (
auto typedAttr = llvm::dyn_cast<TypedAttr>(value))
542 type = typedAttr.getType();
543 if (llvm::isa<NoneType, TensorType>(type)) {
548 if (llvm::isa<TensorArmType>(type)) {
558 printer <<
' ' << getValue();
559 if (llvm::isa<spirv::ArrayType>(
getType()))
565 if (isa<spirv::CooperativeMatrixType>(opType)) {
566 auto denseAttr = dyn_cast<DenseElementsAttr>(value);
567 if (!denseAttr || !denseAttr.isSplat())
568 return op.emitOpError(
"expected a splat dense attribute for cooperative "
569 "matrix constant, but found ")
572 if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
573 auto valueType = llvm::cast<TypedAttr>(value).getType();
574 if (valueType != opType)
575 return op.emitOpError(
"result type (")
576 << opType <<
") does not match value type (" << valueType <<
")";
579 if (llvm::isa<DenseIntOrFPElementsAttr, SparseElementsAttr>(value)) {
580 auto valueType = llvm::cast<TypedAttr>(value).getType();
581 if (valueType == opType)
583 auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
584 auto shapedType = llvm::dyn_cast<ShapedType>(valueType);
586 return op.emitOpError(
"result or element type (")
587 << opType <<
") does not match value type (" << valueType
588 <<
"), must be the same or spirv.array";
590 int numElements = arrayType.getNumElements();
591 auto opElemType = arrayType.getElementType();
592 while (
auto t = llvm::dyn_cast<spirv::ArrayType>(opElemType)) {
593 numElements *= t.getNumElements();
594 opElemType = t.getElementType();
596 if (!opElemType.isIntOrFloat())
597 return op.emitOpError(
"only support nested array result type");
599 auto valueElemType = shapedType.getElementType();
600 if (valueElemType != opElemType) {
601 return op.emitOpError(
"result element type (")
602 << opElemType <<
") does not match value element type ("
603 << valueElemType <<
")";
606 if (numElements != shapedType.getNumElements()) {
607 return op.emitOpError(
"result number of elements (")
608 << numElements <<
") does not match value number of elements ("
609 << shapedType.getNumElements() <<
")";
613 if (
auto arrayAttr = llvm::dyn_cast<ArrayAttr>(value)) {
614 auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
616 return op.emitOpError(
617 "must have spirv.array result type for array value");
618 Type elemType = arrayType.getElementType();
619 for (
Attribute element : arrayAttr.getValue()) {
626 return op.emitOpError(
"cannot have attribute: ") << value;
636 bool spirv::ConstantOp::isBuildableWith(
Type type) {
638 if (!llvm::isa<spirv::SPIRVType>(type))
643 return llvm::isa<spirv::ArrayType>(type);
651 if (
auto intType = llvm::dyn_cast<IntegerType>(type)) {
652 unsigned width = intType.getWidth();
654 return spirv::ConstantOp::create(builder, loc, type,
656 return spirv::ConstantOp::create(
657 builder, loc, type, builder.
getIntegerAttr(type, APInt(width, 0)));
659 if (
auto floatType = llvm::dyn_cast<FloatType>(type)) {
660 return spirv::ConstantOp::create(builder, loc, type,
663 if (
auto vectorType = llvm::dyn_cast<VectorType>(type)) {
664 Type elemType = vectorType.getElementType();
665 if (llvm::isa<IntegerType>(elemType)) {
666 return spirv::ConstantOp::create(
671 if (llvm::isa<FloatType>(elemType)) {
672 return spirv::ConstantOp::create(
679 llvm_unreachable(
"unimplemented types for ConstantOp::getZero()");
682 spirv::ConstantOp spirv::ConstantOp::getOne(
Type type,
Location loc,
684 if (
auto intType = llvm::dyn_cast<IntegerType>(type)) {
685 unsigned width = intType.getWidth();
687 return spirv::ConstantOp::create(builder, loc, type,
689 return spirv::ConstantOp::create(
690 builder, loc, type, builder.
getIntegerAttr(type, APInt(width, 1)));
692 if (
auto floatType = llvm::dyn_cast<FloatType>(type)) {
693 return spirv::ConstantOp::create(builder, loc, type,
696 if (
auto vectorType = llvm::dyn_cast<VectorType>(type)) {
697 Type elemType = vectorType.getElementType();
698 if (llvm::isa<IntegerType>(elemType)) {
699 return spirv::ConstantOp::create(
704 if (llvm::isa<FloatType>(elemType)) {
705 return spirv::ConstantOp::create(
712 llvm_unreachable(
"unimplemented types for ConstantOp::getOne()");
715 void mlir::spirv::ConstantOp::getAsmResultNames(
720 llvm::raw_svector_ostream specialName(specialNameBuffer);
721 specialName <<
"cst";
723 IntegerType intTy = llvm::dyn_cast<IntegerType>(type);
725 if (IntegerAttr intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
726 if (intTy && intTy.getWidth() == 1) {
727 return setNameFn(getResult(), (intCst.getInt() ?
"true" :
"false"));
730 if (intTy.isSignless()) {
731 specialName << intCst.getInt();
732 }
else if (intTy.isUnsigned()) {
733 specialName << intCst.getUInt();
735 specialName << intCst.getSInt();
739 if (intTy || llvm::isa<FloatType>(type)) {
740 specialName <<
'_' << type;
743 if (
auto vecType = llvm::dyn_cast<VectorType>(type)) {
744 specialName <<
"_vec_";
745 specialName << vecType.getDimSize(0);
747 Type elementType = vecType.getElementType();
749 if (llvm::isa<IntegerType>(elementType) ||
750 llvm::isa<FloatType>(elementType)) {
751 specialName <<
"x" << elementType;
755 setNameFn(getResult(), specialName.str());
758 void mlir::spirv::AddressOfOp::getAsmResultNames(
761 llvm::raw_svector_ostream specialName(specialNameBuffer);
762 specialName << getVariable() <<
"_addr";
763 setNameFn(getResult(), specialName.str());
774 if (
auto typedAttr = dyn_cast<TypedAttr>(attr)) {
775 return typedAttr.getType();
778 if (
auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
788 return emitError(
"unknown value attribute type");
790 auto compositeType = dyn_cast<spirv::CompositeType>(
getType());
792 return emitError(
"result type is not a composite type");
794 Type compositeElementType = compositeType.getElementType(0);
797 while (
auto type = dyn_cast<spirv::CompositeType>(compositeElementType)) {
798 compositeElementType = type.getElementType(0);
799 possibleTypes.push_back(compositeElementType);
802 if (!is_contained(possibleTypes, valueType)) {
803 return emitError(
"expected value attribute type ")
804 << interleaved(possibleTypes,
" or ") <<
", but got: " << valueType;
823 spirv::ExecutionModel executionModel,
824 spirv::FuncOp
function,
826 build(builder, state,
833 spirv::ExecutionModel execModel;
837 if (parseEnumStrAttr<spirv::ExecutionModelAttr>(execModel, parser, result) ||
846 FlatSymbolRefAttr var;
848 if (parser.parseAttribute(var, Type(),
"var_symbol", attrs))
850 interfaceVars.push_back(var);
863 auto interfaceVars = getInterface().getValue();
864 if (!interfaceVars.empty())
865 printer <<
", " << llvm::interleaved(interfaceVars);
879 spirv::FuncOp
function,
880 spirv::ExecutionMode executionMode,
889 spirv::ExecutionMode execMode;
892 parseEnumStrAttr<spirv::ExecutionModeAttr>(execMode, parser, result)) {
904 values.push_back(llvm::cast<IntegerAttr>(value).getInt());
906 StringRef valuesAttrName =
907 spirv::ExecutionModeOp::getValuesAttrName(result.
name);
916 printer <<
" \"" << stringifyExecutionMode(getExecutionMode()) <<
"\"";
917 ArrayAttr values = this->getValues();
919 printer <<
", " << llvm::interleaved(values.getAsValueRange<IntegerAttr>());
939 bool isVariadic =
false;
941 parser,
false, entryArgs, isVariadic, resultTypes,
946 for (
auto &arg : entryArgs)
947 argTypes.push_back(arg.type);
953 spirv::FunctionControl fnControl;
954 if (parseEnumStrAttr<spirv::FunctionControlAttr>(fnControl, parser, result))
962 assert(resultAttrs.size() == resultTypes.size());
964 builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.
name),
965 getResAttrsAttrName(result.
name));
971 return failure(parseResult.
has_value() && failed(*parseResult));
978 auto fnType = getFunctionType();
980 printer, *
this, fnType.getInputs(),
981 false, fnType.getResults());
982 printer <<
" \"" << spirv::stringifyFunctionControl(getFunctionControl())
986 {spirv::attributeName<spirv::FunctionControl>(),
987 getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
988 getFunctionControlAttrName()});
991 Region &body = this->getBody();
999 LogicalResult spirv::FuncOp::verifyType() {
1000 FunctionType fnType = getFunctionType();
1001 if (fnType.getNumResults() > 1)
1002 return emitOpError(
"cannot have more than one result");
1004 auto hasDecorationAttr = [&](spirv::Decoration decoration,
1005 unsigned argIndex) {
1006 auto func = llvm::cast<FunctionOpInterface>(getOperation());
1007 for (
auto argAttr : cast<FunctionOpInterface>(func).
getArgAttrs(argIndex)) {
1008 if (argAttr.getName() != spirv::DecorationAttr::name)
1010 if (
auto decAttr = dyn_cast<spirv::DecorationAttr>(argAttr.getValue()))
1011 return decAttr.getValue() == decoration;
1016 for (
unsigned i = 0, e = this->getNumArguments(); i != e; ++i) {
1017 Type param = fnType.getInputs()[i];
1018 auto inputPtrType = dyn_cast<spirv::PointerType>(param);
1022 auto pointeePtrType =
1023 dyn_cast<spirv::PointerType>(inputPtrType.getPointeeType());
1024 if (pointeePtrType) {
1030 if (pointeePtrType.getStorageClass() !=
1031 spirv::StorageClass::PhysicalStorageBuffer)
1034 bool hasAliasedPtr =
1035 hasDecorationAttr(spirv::Decoration::AliasedPointer, i);
1036 bool hasRestrictPtr =
1037 hasDecorationAttr(spirv::Decoration::RestrictPointer, i);
1038 if (!hasAliasedPtr && !hasRestrictPtr)
1039 return emitOpError()
1040 <<
"with a pointer points to a physical buffer pointer must "
1041 "be decorated either 'AliasedPointer' or 'RestrictPointer'";
1048 if (
auto pointeeArrayType =
1049 dyn_cast<spirv::ArrayType>(inputPtrType.getPointeeType())) {
1051 dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
1053 pointeePtrType = inputPtrType;
1056 if (!pointeePtrType || pointeePtrType.getStorageClass() !=
1057 spirv::StorageClass::PhysicalStorageBuffer)
1060 bool hasAliased = hasDecorationAttr(spirv::Decoration::Aliased, i);
1061 bool hasRestrict = hasDecorationAttr(spirv::Decoration::Restrict, i);
1062 if (!hasAliased && !hasRestrict)
1063 return emitOpError() <<
"with physical buffer pointer must be decorated "
1064 "either 'Aliased' or 'Restrict'";
1070 LogicalResult spirv::FuncOp::verifyBody() {
1071 FunctionType fnType = getFunctionType();
1072 if (!isExternal()) {
1073 Block &entryBlock = front();
1075 unsigned numArguments = this->getNumArguments();
1077 return emitOpError(
"entry block must have ")
1078 << numArguments <<
" arguments to match function signature";
1080 for (
auto [index, fnArgType, blockArgType] :
1082 if (blockArgType != fnArgType) {
1083 return emitOpError(
"type of entry block argument #")
1084 << index <<
'(' << blockArgType
1085 <<
") must match the type of the corresponding argument in "
1086 <<
"function signature(" << fnArgType <<
')';
1092 if (
auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
1093 if (fnType.getNumResults() != 0)
1094 return retOp.emitOpError(
"cannot be used in functions returning value");
1095 }
else if (
auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
1096 if (fnType.getNumResults() != 1)
1097 return retOp.emitOpError(
1098 "returns 1 value but enclosing function requires ")
1099 << fnType.getNumResults() <<
" results";
1101 auto retOperandType = retOp.getValue().getType();
1102 auto fnResultType = fnType.getResult(0);
1103 if (retOperandType != fnResultType)
1104 return retOp.emitOpError(
" return value's type (")
1105 << retOperandType <<
") mismatch with function's result type ("
1106 << fnResultType <<
")";
1113 return failure(walkResult.wasInterrupted());
1117 StringRef name, FunctionType type,
1118 spirv::FunctionControl control,
1122 state.addAttribute(getFunctionTypeAttrName(state.name),
TypeAttr::get(type));
1123 state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
1124 builder.
getAttr<spirv::FunctionControlAttr>(control));
1125 state.attributes.append(attrs.begin(), attrs.end());
1173 Type type, StringRef name,
1174 unsigned descriptorSet,
unsigned binding) {
1177 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
1180 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
1185 Type type, StringRef name,
1186 spirv::BuiltIn builtin) {
1189 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
1196 StringAttr nameAttr;
1197 StringRef initializerAttrName =
1198 spirv::GlobalVariableOp::getInitializerAttrName(result.
name);
1219 StringRef typeAttrName =
1220 spirv::GlobalVariableOp::getTypeAttrName(result.
name);
1225 if (!llvm::isa<spirv::PointerType>(type)) {
1226 return parser.
emitError(loc,
"expected spirv.ptr type");
1235 spirv::attributeName<spirv::StorageClass>()};
1242 StringRef initializerAttrName = this->getInitializerAttrName();
1244 if (
auto initializer = this->getInitializer()) {
1245 printer <<
" " << initializerAttrName <<
'(';
1248 elidedAttrs.push_back(initializerAttrName);
1251 StringRef typeAttrName = this->getTypeAttrName();
1252 elidedAttrs.push_back(typeAttrName);
1254 printer <<
" : " <<
getType();
1258 if (!llvm::isa<spirv::PointerType>(
getType()))
1259 return emitOpError(
"result must be of a !spv.ptr type");
1265 auto storageClass = this->storageClass();
1266 if (storageClass == spirv::StorageClass::Generic ||
1267 storageClass == spirv::StorageClass::Function) {
1268 return emitOpError(
"storage class cannot be '")
1269 << stringifyStorageClass(storageClass) <<
"'";
1273 this->getInitializerAttrName())) {
1275 (*this)->getParentOp(), init.getAttr());
1279 if (!initOp || !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp,
1280 spirv::SpecConstantCompositeOp>(initOp)) {
1281 return emitOpError(
"initializer must be result of a "
1282 "spirv.SpecConstant or spirv.GlobalVariable or "
1283 "spirv.SpecConstantCompositeOp op");
1308 spirv::StorageClass storageClass;
1319 if (
auto valVecTy = llvm::dyn_cast<VectorType>(elementType))
1330 printer <<
" " << getPtr() <<
", " << getValue() <<
" : "
1331 << getValue().getType();
1422 std::optional<StringRef> name) {
1432 spirv::AddressingModel addressingModel,
1433 spirv::MemoryModel memoryModel,
1434 std::optional<VerCapExtAttr> vceTriple,
1435 std::optional<StringRef> name) {
1438 builder.
getAttr<spirv::AddressingModelAttr>(addressingModel));
1439 state.addAttribute(
"memory_model",
1440 builder.
getAttr<spirv::MemoryModelAttr>(memoryModel));
1444 state.addAttribute(getVCETripleAttrName(), *vceTriple);
1455 StringAttr nameAttr;
1460 spirv::AddressingModel addrModel;
1461 spirv::MemoryModel memoryModel;
1462 if (spirv::parseEnumKeywordAttr<spirv::AddressingModelAttr>(addrModel, parser,
1464 spirv::parseEnumKeywordAttr<spirv::MemoryModelAttr>(memoryModel, parser,
1471 spirv::ModuleOp::getVCETripleAttrName(),
1488 if (std::optional<StringRef> name = getName()) {
1497 auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
1498 auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
1499 elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
1502 if (std::optional<spirv::VerCapExtAttr> triple = getVceTriple()) {
1503 printer <<
" requires " << *triple;
1504 elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
1512 LogicalResult spirv::ModuleOp::verifyRegions() {
1513 Dialect *dialect = (*this)->getDialect();
1518 for (
auto &op : *getBody()) {
1520 return op.
emitError(
"'spirv.module' can only contain spirv.* ops");
1525 if (
auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
1526 auto funcOp =
table.lookup<spirv::FuncOp>(entryPointOp.getFn());
1528 return entryPointOp.emitError(
"function '")
1529 << entryPointOp.getFn() <<
"' not found in 'spirv.module'";
1531 if (
auto interface = entryPointOp.getInterface()) {
1533 auto varSymRef = llvm::dyn_cast<FlatSymbolRefAttr>(varRef);
1535 return entryPointOp.emitError(
1536 "expected symbol reference for interface "
1537 "specification instead of '")
1541 table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
1543 return entryPointOp.emitError(
"expected spirv.GlobalVariable "
1544 "symbol reference instead of'")
1545 << varSymRef <<
"'";
1550 auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
1551 funcOp, entryPointOp.getExecutionModel());
1552 if (!entryPoints.try_emplace(key, entryPointOp).second)
1553 return entryPointOp.emitError(
"duplicate of a previous EntryPointOp");
1554 }
else if (
auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
1558 auto linkageAttr = funcOp.getLinkageAttributes();
1559 auto hasImportLinkage =
1560 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
1561 spirv::LinkageType::Import);
1562 if (funcOp.isExternal() && !hasImportLinkage)
1564 "'spirv.module' cannot contain external functions "
1565 "without 'Import' linkage_attributes (LinkageAttributes)");
1568 for (
auto &block : funcOp)
1569 for (
auto &op : block) {
1572 "functions in 'spirv.module' can only contain spirv.* ops");
1586 (*this)->getParentOp(), getSpecConstAttr());
1589 auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
1591 constType = specConstOp.getDefaultValue().getType();
1593 auto specConstCompositeOp =
1594 dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
1595 if (specConstCompositeOp)
1596 constType = specConstCompositeOp.getType();
1598 if (!specConstOp && !specConstCompositeOp)
1600 "expected spirv.SpecConstant or spirv.SpecConstantComposite symbol");
1602 if (getReference().
getType() != constType)
1603 return emitOpError(
"result type mismatch with the referenced "
1604 "specialization constant's type");
1615 StringAttr nameAttr;
1617 StringRef defaultValueAttrName =
1618 spirv::SpecConstantOp::getDefaultValueAttrName(result.
name);
1626 IntegerAttr specIdAttr;
1643 if (
auto specID = (*this)->getAttrOfType<IntegerAttr>(
kSpecIdAttrName))
1645 printer <<
" = " << getDefaultValue();
1649 if (
auto specID = (*this)->getAttrOfType<IntegerAttr>(
kSpecIdAttrName))
1650 if (specID.getValue().isNegative())
1651 return emitOpError(
"SpecId cannot be negative");
1653 auto value = getDefaultValue();
1654 if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
1656 if (!llvm::isa<spirv::SPIRVType>(value.getType()))
1657 return emitOpError(
"default value bitwidth disallowed");
1661 "default value can only be a bool, integer, or float scalar");
1669 VectorType resultType = llvm::cast<VectorType>(
getType());
1671 size_t numResultElements = resultType.getNumElements();
1672 if (numResultElements != getComponents().size())
1673 return emitOpError(
"result type element count (")
1674 << numResultElements
1675 <<
") mismatch with the number of component selectors ("
1676 << getComponents().size() <<
")";
1678 size_t totalSrcElements =
1679 llvm::cast<VectorType>(getVector1().
getType()).getNumElements() +
1680 llvm::cast<VectorType>(getVector2().
getType()).getNumElements();
1682 for (
const auto &selector : getComponents().getAsValueRange<IntegerAttr>()) {
1683 uint32_t index = selector.getZExtValue();
1684 if (index >= totalSrcElements &&
1685 index != std::numeric_limits<uint32_t>().
max())
1686 return emitOpError(
"component selector ")
1687 << index <<
" out of range: expected to be in [0, "
1688 << totalSrcElements <<
") or 0xffffffff";
1701 [](
auto matrixType) {
return matrixType.getElementType(); })
1702 .Default([](
Type) {
return nullptr; });
1704 assert(elementType &&
"Unhandled type");
1707 if (getScalar().
getType() != elementType)
1708 return emitOpError(
"input matrix components' type and scaling value must "
1709 "have the same type");
1719 auto inputMatrix = llvm::cast<spirv::MatrixType>(getMatrix().
getType());
1720 auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().
getType());
1723 if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
1724 return emitError(
"input matrix rows count must be equal to "
1725 "output matrix columns count");
1727 if (inputMatrix.getNumColumns() != resultMatrix.getNumRows())
1728 return emitError(
"input matrix columns count must be equal to "
1729 "output matrix rows count");
1732 if (inputMatrix.getElementType() != resultMatrix.getElementType())
1733 return emitError(
"input and output matrices must have the same "
1744 auto matrixType = llvm::cast<spirv::MatrixType>(getMatrix().
getType());
1745 auto vectorType = llvm::cast<VectorType>(getVector().
getType());
1746 auto resultType = llvm::cast<VectorType>(
getType());
1748 if (matrixType.getNumColumns() != vectorType.getNumElements())
1749 return emitOpError(
"matrix columns (")
1750 << matrixType.getNumColumns() <<
") must match vector operand size ("
1751 << vectorType.getNumElements() <<
")";
1753 if (resultType.getNumElements() != matrixType.getNumRows())
1754 return emitOpError(
"result size (")
1755 << resultType.getNumElements() <<
") must match the matrix rows ("
1756 << matrixType.getNumRows() <<
")";
1758 if (matrixType.getElementType() != resultType.getElementType())
1759 return emitOpError(
"matrix and result element types must match");
1769 auto vectorType = llvm::cast<VectorType>(getVector().
getType());
1770 auto matrixType = llvm::cast<spirv::MatrixType>(getMatrix().
getType());
1771 auto resultType = llvm::cast<VectorType>(
getType());
1773 if (matrixType.getNumRows() != vectorType.getNumElements())
1774 return emitOpError(
"number of components in vector must equal the number "
1775 "of components in each column in matrix");
1777 if (resultType.getNumElements() != matrixType.getNumColumns())
1778 return emitOpError(
"number of columns in matrix must equal the number of "
1779 "components in result");
1781 if (matrixType.getElementType() != resultType.getElementType())
1782 return emitOpError(
"matrix must be a matrix with the same component type "
1783 "as the component type in result");
1793 auto leftMatrix = llvm::cast<spirv::MatrixType>(getLeftmatrix().
getType());
1794 auto rightMatrix = llvm::cast<spirv::MatrixType>(getRightmatrix().
getType());
1795 auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().
getType());
1798 if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
1799 return emitError(
"left matrix columns' count must be equal to "
1800 "the right matrix rows' count");
1803 if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns())
1805 "right and result matrices must have equal columns' count");
1808 if (rightMatrix.getElementType() != resultMatrix.getElementType())
1809 return emitError(
"right and result matrices' component type must"
1813 if (leftMatrix.getElementType() != resultMatrix.getElementType())
1814 return emitError(
"left and result matrices' component type"
1815 " must be the same");
1818 if (leftMatrix.getNumRows() != resultMatrix.getNumRows())
1819 return emitError(
"left and result matrices must have equal rows' count");
1831 StringAttr compositeName;
1843 const char *attrName =
"spec_const";
1850 constituents.push_back(specConstRef);
1856 StringAttr compositeSpecConstituentsName =
1857 spirv::SpecConstantCompositeOp::getConstituentsAttrName(result.
name);
1865 StringAttr typeAttrName =
1866 spirv::SpecConstantCompositeOp::getTypeAttrName(result.
name);
1875 printer <<
" (" << llvm::interleaved(this->getConstituents().getValue())
1880 auto cType = llvm::dyn_cast<spirv::CompositeType>(
getType());
1881 auto constituents = this->getConstituents().getValue();
1884 return emitError(
"result type must be a composite type, but provided ")
1887 if (llvm::isa<spirv::CooperativeMatrixType>(cType))
1888 return emitError(
"unsupported composite type ") << cType;
1889 if (constituents.size() != cType.getNumElements())
1890 return emitError(
"has incorrect number of operands: expected ")
1891 << cType.getNumElements() <<
", but provided "
1892 << constituents.size();
1894 for (
auto index : llvm::seq<uint32_t>(0, constituents.size())) {
1895 auto constituent = llvm::cast<FlatSymbolRefAttr>(constituents[index]);
1897 auto constituentSpecConstOp =
1899 (*this)->getParentOp(), constituent.getAttr()));
1901 if (constituentSpecConstOp.getDefaultValue().getType() !=
1902 cType.getElementType(index))
1903 return emitError(
"has incorrect types of operands: expected ")
1904 << cType.getElementType(index) <<
", but provided "
1905 << constituentSpecConstOp.getDefaultValue().getType();
1918 StringAttr compositeName;
1920 const char *attrName =
"spec_const";
1931 StringAttr compositeSpecConstituentName =
1932 spirv::EXTSpecConstantCompositeReplicateOp::getConstituentAttrName(
1934 result.
addAttribute(compositeSpecConstituentName, specConstRef);
1936 StringAttr typeAttrName =
1937 spirv::EXTSpecConstantCompositeReplicateOp::getTypeAttrName(result.
name);
1946 printer <<
" (" << this->getConstituent() <<
") : " <<
getType();
1950 auto compositeType = dyn_cast<spirv::CompositeType>(
getType());
1952 return emitError(
"result type must be a composite type, but provided ")
1956 (*this)->getParentOp(), this->getConstituent());
1959 "splat spec constant reference defining constituent not found");
1961 auto constituentSpecConstOp = dyn_cast<spirv::SpecConstantOp>(constituentOp);
1962 if (!constituentSpecConstOp)
1963 return emitError(
"constituent is not a spec constant");
1965 Type constituentType = constituentSpecConstOp.getDefaultValue().getType();
1966 Type compositeElementType = compositeType.getElementType(0);
1967 if (constituentType != compositeElementType)
1968 return emitError(
"constituent has incorrect type: expected ")
1969 << compositeElementType <<
", but provided " << constituentType;
1994 spirv::YieldOp::create(builder, wrappedOp->
getLoc(), wrappedOp->
getResult(0));
2006 printer <<
" wraps ";
2010 LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
2011 Block &block = getRegion().getBlocks().
front();
2014 return emitOpError(
"expected exactly 2 nested ops");
2019 return emitOpError(
"invalid enclosed op");
2022 if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
2023 spirv::SpecConstantOperationOp>(operand.getDefiningOp()))
2025 "invalid operand, must be defined by a constant operation");
2036 llvm::dyn_cast<spirv::StructType>(getResult().
getType());
2039 return emitError(
"result type must be a struct type with two memebers");
2043 VectorType exponentVecTy = llvm::dyn_cast<VectorType>(exponentTy);
2044 IntegerType exponentIntTy = llvm::dyn_cast<IntegerType>(exponentTy);
2046 Type operandTy = getOperand().getType();
2047 VectorType operandVecTy = llvm::dyn_cast<VectorType>(operandTy);
2048 FloatType operandFTy = llvm::dyn_cast<FloatType>(operandTy);
2050 if (significandTy != operandTy)
2051 return emitError(
"member zero of the resulting struct type must be the "
2052 "same type as the operand");
2054 if (exponentVecTy) {
2055 IntegerType componentIntTy =
2056 llvm::dyn_cast<IntegerType>(exponentVecTy.getElementType());
2057 if (!componentIntTy || componentIntTy.getWidth() != 32)
2058 return emitError(
"member one of the resulting struct type must"
2059 "be a scalar or vector of 32 bit integer type");
2060 }
else if (!exponentIntTy || exponentIntTy.getWidth() != 32) {
2061 return emitError(
"member one of the resulting struct type "
2062 "must be a scalar or vector of 32 bit integer type");
2066 if (operandVecTy && exponentVecTy &&
2067 (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
2070 if (operandFTy && exponentIntTy)
2073 return emitError(
"member one of the resulting struct type must have the same "
2074 "number of components as the operand type");
2082 Type significandType = getX().getType();
2083 Type exponentType = getExp().getType();
2085 if (llvm::isa<FloatType>(significandType) !=
2086 llvm::isa<IntegerType>(exponentType))
2087 return emitOpError(
"operands must both be scalars or vectors");
2090 if (
auto vectorType = llvm::dyn_cast<VectorType>(type))
2091 return vectorType.getNumElements();
2096 return emitOpError(
"operands must have the same number of elements");
2131 return emitOpError(
"vector operand and result type mismatch");
2132 auto scalarType = llvm::cast<VectorType>(
getType()).getElementType();
2133 if (getScalar().
getType() != scalarType)
2134 return emitOpError(
"scalar operand and result element type match");
static std::string bindingName()
Returns the string name of the Binding decoration.
static std::string descriptorSetName()
Returns the string name of the DescriptorSet decoration.
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser, OperationState &result)
static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, Type opType)
static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, OperationState &result)
static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op)
static LogicalResult verifyShiftOp(Operation *op)
static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op, Value ptr, Value val)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static void printOneResultOp(Operation *op, OpAsmPrinter &p)
static void printArithmeticExtendedBinaryOp(Operation *op, OpAsmPrinter &printer)
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseOptionalSymbolName(StringAttr &result)=0
Parse an optional -identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
ParseResult addTypesToList(ArrayRef< Type > types, SmallVectorImpl< Type > &result)
Add the specified types to the end of the specified type list and return success.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
unsigned getNumArguments()
OpListType & getOperations()
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
FloatAttr getFloatAttr(Type type, double value)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
StringAttr getStringAttr(const Twine &bytes)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual Operation * parseGenericOperation(Block *insertBlock, Block::iterator insertPt)=0
Parse an operation in its generic form.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printGenericOp(Operation *op, bool printOpName=true)=0
Print the entire operation with the default generic assembly form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
A trait to mark ops that can be enclosed/wrapped in a SpecConstantOperation op.
type_range getType() const
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
AttrClass getAttrOfType(StringAttr name)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
unsigned getNumResults()
Return the number of results held by this operation.
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
void push_back(Block *block)
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Dialect & getDialect() const
Get the dialect this type is registered to.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
static ArrayType get(Type elementType, unsigned elementCount)
static PointerType get(Type pointeeType, StorageClass storageClass)
unsigned getNumElements() const
Type getElementType(unsigned) const
An attribute that specifies the SPIR-V (version, capabilities, extensions) triple.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the argument at 'index'.
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
constexpr char kFnNameAttrName[]
constexpr char kSpecIdAttrName[]
static Type getValueType(Attribute attr)
LogicalResult verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics)
ParseResult parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next string attribute in parser as an enumerant of the given EnumClass.
void printVariableDecorations(Operation *op, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs)
AddressingModel getAddressingModel(TargetEnvAttr targetAttr, bool use64bitAddress)
Returns addressing model selected based on target environment.
FailureOr< ExecutionModel > getExecutionModel(TargetEnvAttr targetAttr)
Returns execution model selected based on target environment.
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
LogicalResult extractValueFromConstOp(Operation *op, int32_t &value)
ParseResult parseVariableDecorations(OpAsmParser &parser, OperationState &state)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.