32 #include "llvm/ADT/APFloat.h"
33 #include "llvm/ADT/APInt.h"
34 #include "llvm/ADT/ArrayRef.h"
35 #include "llvm/ADT/STLExtras.h"
36 #include "llvm/ADT/StringExtras.h"
37 #include "llvm/ADT/TypeSwitch.h"
41 #include <type_traits>
51 auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
55 auto valueAttr = constOp.getValue();
56 auto integerValueAttr = llvm::dyn_cast<IntegerAttr>(valueAttr);
57 if (!integerValueAttr) {
61 if (integerValueAttr.getType().isSignlessInteger())
62 value = integerValueAttr.getInt();
64 value = integerValueAttr.getSInt();
71 spirv::MemorySemantics memorySemantics) {
78 auto atMostOneInSet = spirv::MemorySemantics::Acquire |
79 spirv::MemorySemantics::Release |
80 spirv::MemorySemantics::AcquireRelease |
81 spirv::MemorySemantics::SequentiallyConsistent;
84 llvm::popcount(
static_cast<uint32_t
>(memorySemantics & atMostOneInSet));
87 "expected at most one of these four memory constraints "
88 "to be set: `Acquire`, `Release`,"
89 "`AcquireRelease` or `SequentiallyConsistent`");
98 stringifyDecoration(spirv::Decoration::DescriptorSet));
99 auto bindingName = llvm::convertToSnakeFromCamelCase(
100 stringifyDecoration(spirv::Decoration::Binding));
103 if (descriptorSet && binding) {
106 printer <<
" bind(" << descriptorSet.getInt() <<
", " << binding.getInt()
111 auto builtInName = llvm::convertToSnakeFromCamelCase(
112 stringifyDecoration(spirv::Decoration::BuiltIn));
113 if (
auto builtin = op->
getAttrOfType<StringAttr>(builtInName)) {
114 printer <<
" " << builtInName <<
"(\"" << builtin.getValue() <<
"\")";
115 elidedAttrs.push_back(builtInName);
133 auto fnType = llvm::dyn_cast<FunctionType>(type);
135 parser.
emitError(loc,
"expected function type");
140 result.
addTypes(fnType.getResults());
151 assert(op->
getNumResults() == 1 &&
"op should have one result");
157 [&](
Type type) { return type != resultType; })) {
166 p <<
" : " << resultType;
169 template <
typename Op>
171 spirv::ImageOperandsAttr attr,
174 if (operands.empty())
177 return imageOp.
emitError(
"the Image Operands should encode what operands "
178 "follow, as per Image Operands");
182 spirv::ImageOperands noSupportOperands =
183 spirv::ImageOperands::Bias | spirv::ImageOperands::Lod |
184 spirv::ImageOperands::Grad | spirv::ImageOperands::ConstOffset |
185 spirv::ImageOperands::Offset | spirv::ImageOperands::ConstOffsets |
186 spirv::ImageOperands::Sample | spirv::ImageOperands::MinLod |
187 spirv::ImageOperands::MakeTexelAvailable |
188 spirv::ImageOperands::MakeTexelVisible |
189 spirv::ImageOperands::SignExtend | spirv::ImageOperands::ZeroExtend;
191 if (spirv::bitEnumContainsAll(attr.getValue(), noSupportOperands))
192 llvm_unreachable(
"unimplemented operands of Image Operands");
197 template <
typename BlockReadWriteOpTy>
201 if (
auto valVecTy = llvm::dyn_cast<VectorType>(valType))
202 valType = valVecTy.getElementType();
205 llvm::cast<spirv::PointerType>(ptr.
getType()).getPointeeType()) {
206 return op.
emitOpError(
"mismatch in result type and pointer type");
217 if (indices.empty()) {
218 emitErrorFn(
"expected at least one index for spirv.CompositeExtract");
222 for (
auto index : indices) {
223 if (
auto cType = llvm::dyn_cast<spirv::CompositeType>(type)) {
224 if (cType.hasCompileTimeKnownNumElements() &&
226 static_cast<uint64_t
>(index) >= cType.getNumElements())) {
227 emitErrorFn(
"index ") << index <<
" out of bounds for " << type;
230 type = cType.getElementType(index);
232 emitErrorFn(
"cannot extract from non-composite type ")
233 << type <<
" with index " << index;
243 auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(indices);
244 if (!indicesArrayAttr) {
245 emitErrorFn(
"expected a 32-bit integer array attribute for 'indices'");
248 if (indicesArrayAttr.empty()) {
249 emitErrorFn(
"expected at least one index for spirv.CompositeExtract");
254 for (
auto indexAttr : indicesArrayAttr) {
255 auto indexIntAttr = llvm::dyn_cast<IntegerAttr>(indexAttr);
257 emitErrorFn(
"expected an 32-bit integer for index, but found '")
261 indexVals.push_back(indexIntAttr.getInt());
281 template <
typename ExtendedBinaryOp>
283 auto resultType = llvm::cast<spirv::StructType>(op.getType());
284 if (resultType.getNumElements() != 2)
285 return op.
emitOpError(
"expected result struct type containing two members");
287 if (!llvm::all_equal({op.getOperand1().
getType(), op.getOperand2().getType(),
288 resultType.getElementType(0),
289 resultType.getElementType(1)}))
291 "expected all operand types and struct member types are the same");
308 auto structType = llvm::dyn_cast<spirv::StructType>(resultType);
309 if (!structType || structType.getNumElements() != 2)
310 return parser.
emitError(loc,
"expected spirv.struct type with two members");
330 return op->
emitError(
"expected the same type for the first operand and "
331 "result, but provided ")
343 spirv::GlobalVariableOp var) {
348 auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
352 return emitOpError(
"expected spirv.GlobalVariable symbol");
354 if (getPointer().
getType() != varOp.getType()) {
356 "result type mismatch with the referenced global variable's type");
366 operand_range constituents = this->getConstituents();
374 auto coopElementType =
377 [](
auto coopType) {
return coopType.getElementType(); })
378 .Default([](
Type) {
return nullptr; });
381 if (coopElementType) {
382 if (constituents.size() != 1)
383 return emitOpError(
"has incorrect number of operands: expected ")
384 <<
"1, but provided " << constituents.size();
385 if (coopElementType != constituents.front().getType())
386 return emitOpError(
"operand type mismatch: expected operand type ")
387 << coopElementType <<
", but provided "
388 << constituents.front().getType();
393 auto cType = llvm::cast<spirv::CompositeType>(
getType());
394 if (constituents.size() == cType.getNumElements()) {
395 for (
auto index : llvm::seq<uint32_t>(0, constituents.size())) {
396 if (constituents[index].
getType() != cType.getElementType(index)) {
397 return emitOpError(
"operand type mismatch: expected operand type ")
398 << cType.getElementType(index) <<
", but provided "
399 << constituents[index].getType();
406 auto resultType = llvm::dyn_cast<VectorType>(cType);
409 "expected to return a vector or cooperative matrix when the number of "
410 "constituents is less than what the result needs");
413 for (
Value component : constituents) {
414 if (!llvm::isa<VectorType>(component.getType()) &&
415 !component.getType().isIntOrFloat())
416 return emitOpError(
"operand type mismatch: expected operand to have "
417 "a scalar or vector type, but provided ")
418 << component.getType();
420 Type elementType = component.getType();
421 if (
auto vectorType = llvm::dyn_cast<VectorType>(component.getType())) {
422 sizes.push_back(vectorType.getNumElements());
423 elementType = vectorType.getElementType();
428 if (elementType != resultType.getElementType())
429 return emitOpError(
"operand element type mismatch: expected to be ")
430 << resultType.getElementType() <<
", but provided " << elementType;
432 unsigned totalCount = std::accumulate(sizes.begin(), sizes.end(), 0);
433 if (totalCount != cType.getNumElements())
434 return emitOpError(
"has incorrect number of operands: expected ")
435 << cType.getNumElements() <<
", but provided " << totalCount;
452 build(builder, state, elementType, composite, indexAttr);
459 StringRef indicesAttrName =
460 spirv::CompositeExtractOp::getIndicesAttrName(result.
name);
482 printer <<
' ' << getComposite() <<
getIndices() <<
" : "
487 auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(
getIndices());
494 return emitOpError(
"invalid result type: expected ")
495 << resultType <<
" but provided " <<
getType();
509 build(builder, state, composite.
getType(),
object, composite, indexAttr);
515 Type objectType, compositeType;
517 StringRef indicesAttrName =
518 spirv::CompositeInsertOp::getIndicesAttrName(result.
name);
532 auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(
getIndices());
538 if (objectType != getObject().
getType()) {
539 return emitOpError(
"object operand type should be ")
540 << objectType <<
", but found " << getObject().getType();
544 return emitOpError(
"result type should be the same as "
545 "the composite type, but found ")
546 << getComposite().getType() <<
" vs " <<
getType();
553 printer <<
" " << getObject() <<
", " << getComposite() <<
getIndices()
554 <<
" : " << getObject().
getType() <<
" into "
555 << getComposite().getType();
565 StringRef valueAttrName = spirv::ConstantOp::getValueAttrName(result.
name);
570 if (
auto typedAttr = llvm::dyn_cast<TypedAttr>(value))
571 type = typedAttr.getType();
572 if (llvm::isa<NoneType, TensorType>(type)) {
581 printer <<
' ' << getValue();
582 if (llvm::isa<spirv::ArrayType>(
getType()))
588 if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
589 auto valueType = llvm::cast<TypedAttr>(value).getType();
590 if (valueType != opType)
592 << opType <<
") does not match value type (" << valueType <<
")";
595 if (llvm::isa<DenseIntOrFPElementsAttr, SparseElementsAttr>(value)) {
596 auto valueType = llvm::cast<TypedAttr>(value).getType();
597 if (valueType == opType)
599 auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
600 auto shapedType = llvm::dyn_cast<ShapedType>(valueType);
603 << opType <<
") does not match value type (" << valueType
604 <<
"), must be the same or spirv.array";
606 int numElements = arrayType.getNumElements();
607 auto opElemType = arrayType.getElementType();
608 while (
auto t = llvm::dyn_cast<spirv::ArrayType>(opElemType)) {
609 numElements *= t.getNumElements();
610 opElemType = t.getElementType();
612 if (!opElemType.isIntOrFloat())
613 return op.
emitOpError(
"only support nested array result type");
615 auto valueElemType = shapedType.getElementType();
616 if (valueElemType != opElemType) {
618 << opElemType <<
") does not match value element type ("
619 << valueElemType <<
")";
622 if (numElements != shapedType.getNumElements()) {
623 return op.
emitOpError(
"result number of elements (")
624 << numElements <<
") does not match value number of elements ("
625 << shapedType.getNumElements() <<
")";
629 if (
auto arrayAttr = llvm::dyn_cast<ArrayAttr>(value)) {
630 auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
633 "must have spirv.array result type for array value");
634 Type elemType = arrayType.getElementType();
635 for (
Attribute element : arrayAttr.getValue()) {
642 return op.
emitOpError(
"cannot have attribute: ") << value;
652 bool spirv::ConstantOp::isBuildableWith(
Type type) {
654 if (!llvm::isa<spirv::SPIRVType>(type))
659 return llvm::isa<spirv::ArrayType>(type);
667 if (
auto intType = llvm::dyn_cast<IntegerType>(type)) {
668 unsigned width = intType.getWidth();
670 return builder.
create<spirv::ConstantOp>(loc, type,
672 return builder.
create<spirv::ConstantOp>(
675 if (
auto floatType = llvm::dyn_cast<FloatType>(type)) {
676 return builder.
create<spirv::ConstantOp>(
679 if (
auto vectorType = llvm::dyn_cast<VectorType>(type)) {
680 Type elemType = vectorType.getElementType();
681 if (llvm::isa<IntegerType>(elemType)) {
682 return builder.
create<spirv::ConstantOp>(
687 if (llvm::isa<FloatType>(elemType)) {
688 return builder.
create<spirv::ConstantOp>(
695 llvm_unreachable(
"unimplemented types for ConstantOp::getZero()");
698 spirv::ConstantOp spirv::ConstantOp::getOne(
Type type,
Location loc,
700 if (
auto intType = llvm::dyn_cast<IntegerType>(type)) {
701 unsigned width = intType.getWidth();
703 return builder.
create<spirv::ConstantOp>(loc, type,
705 return builder.
create<spirv::ConstantOp>(
708 if (
auto floatType = llvm::dyn_cast<FloatType>(type)) {
709 return builder.
create<spirv::ConstantOp>(
712 if (
auto vectorType = llvm::dyn_cast<VectorType>(type)) {
713 Type elemType = vectorType.getElementType();
714 if (llvm::isa<IntegerType>(elemType)) {
715 return builder.
create<spirv::ConstantOp>(
720 if (llvm::isa<FloatType>(elemType)) {
721 return builder.
create<spirv::ConstantOp>(
728 llvm_unreachable(
"unimplemented types for ConstantOp::getOne()");
731 void mlir::spirv::ConstantOp::getAsmResultNames(
736 llvm::raw_svector_ostream specialName(specialNameBuffer);
737 specialName <<
"cst";
739 IntegerType intTy = llvm::dyn_cast<IntegerType>(type);
741 if (IntegerAttr intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
742 if (intTy && intTy.getWidth() == 1) {
743 return setNameFn(getResult(), (intCst.getInt() ?
"true" :
"false"));
746 if (intTy.isSignless()) {
747 specialName << intCst.getInt();
748 }
else if (intTy.isUnsigned()) {
749 specialName << intCst.getUInt();
751 specialName << intCst.getSInt();
755 if (intTy || llvm::isa<FloatType>(type)) {
756 specialName <<
'_' << type;
759 if (
auto vecType = llvm::dyn_cast<VectorType>(type)) {
760 specialName <<
"_vec_";
761 specialName << vecType.getDimSize(0);
763 Type elementType = vecType.getElementType();
765 if (llvm::isa<IntegerType>(elementType) ||
766 llvm::isa<FloatType>(elementType)) {
767 specialName <<
"x" << elementType;
771 setNameFn(getResult(), specialName.str());
774 void mlir::spirv::AddressOfOp::getAsmResultNames(
777 llvm::raw_svector_ostream specialName(specialNameBuffer);
778 specialName << getVariable() <<
"_addr";
779 setNameFn(getResult(), specialName.str());
795 spirv::ExecutionModel executionModel,
796 spirv::FuncOp
function,
798 build(builder, state,
805 spirv::ExecutionModel execModel;
811 if (parseEnumStrAttr<spirv::ExecutionModelAttr>(execModel, parser, result) ||
820 FlatSymbolRefAttr var;
822 if (parser.parseAttribute(var, Type(),
"var_symbol", attrs))
824 interfaceVars.push_back(var);
837 auto interfaceVars = getInterface().getValue();
838 if (!interfaceVars.empty()) {
840 llvm::interleaveComma(interfaceVars, printer);
855 spirv::FuncOp
function,
856 spirv::ExecutionMode executionMode,
865 spirv::ExecutionMode execMode;
868 parseEnumStrAttr<spirv::ExecutionModeAttr>(execMode, parser, result)) {
880 values.push_back(llvm::cast<IntegerAttr>(value).getInt());
882 StringRef valuesAttrName =
883 spirv::ExecutionModeOp::getValuesAttrName(result.
name);
892 printer <<
" \"" << stringifyExecutionMode(getExecutionMode()) <<
"\"";
893 auto values = this->getValues();
897 llvm::interleaveComma(values, printer, [&](
Attribute a) {
898 printer << llvm::cast<IntegerAttr>(a).getInt();
919 bool isVariadic =
false;
921 parser,
false, entryArgs, isVariadic, resultTypes,
926 for (
auto &arg : entryArgs)
927 argTypes.push_back(arg.type);
933 spirv::FunctionControl fnControl;
934 if (parseEnumStrAttr<spirv::FunctionControlAttr>(fnControl, parser, result))
942 assert(resultAttrs.size() == resultTypes.size());
944 builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.
name),
945 getResAttrsAttrName(result.
name));
951 return failure(parseResult.
has_value() && failed(*parseResult));
958 auto fnType = getFunctionType();
960 printer, *
this, fnType.getInputs(),
961 false, fnType.getResults());
962 printer <<
" \"" << spirv::stringifyFunctionControl(getFunctionControl())
966 {spirv::attributeName<spirv::FunctionControl>(),
967 getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
968 getFunctionControlAttrName()});
971 Region &body = this->getBody();
979 LogicalResult spirv::FuncOp::verifyType() {
980 FunctionType fnType = getFunctionType();
981 if (fnType.getNumResults() > 1)
982 return emitOpError(
"cannot have more than one result");
984 auto hasDecorationAttr = [&](spirv::Decoration decoration,
986 auto func = llvm::cast<FunctionOpInterface>(getOperation());
987 for (
auto argAttr : cast<FunctionOpInterface>(func).
getArgAttrs(argIndex)) {
988 if (argAttr.getName() != spirv::DecorationAttr::name)
990 if (
auto decAttr = dyn_cast<spirv::DecorationAttr>(argAttr.getValue()))
991 return decAttr.getValue() == decoration;
996 for (
unsigned i = 0, e = this->getNumArguments(); i != e; ++i) {
997 Type param = fnType.getInputs()[i];
998 auto inputPtrType = dyn_cast<spirv::PointerType>(param);
1002 auto pointeePtrType =
1003 dyn_cast<spirv::PointerType>(inputPtrType.getPointeeType());
1004 if (pointeePtrType) {
1010 if (pointeePtrType.getStorageClass() !=
1011 spirv::StorageClass::PhysicalStorageBuffer)
1014 bool hasAliasedPtr =
1015 hasDecorationAttr(spirv::Decoration::AliasedPointer, i);
1016 bool hasRestrictPtr =
1017 hasDecorationAttr(spirv::Decoration::RestrictPointer, i);
1018 if (!hasAliasedPtr && !hasRestrictPtr)
1019 return emitOpError()
1020 <<
"with a pointer points to a physical buffer pointer must "
1021 "be decorated either 'AliasedPointer' or 'RestrictPointer'";
1028 if (
auto pointeeArrayType =
1029 dyn_cast<spirv::ArrayType>(inputPtrType.getPointeeType())) {
1031 dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
1033 pointeePtrType = inputPtrType;
1036 if (!pointeePtrType || pointeePtrType.getStorageClass() !=
1037 spirv::StorageClass::PhysicalStorageBuffer)
1040 bool hasAliased = hasDecorationAttr(spirv::Decoration::Aliased, i);
1041 bool hasRestrict = hasDecorationAttr(spirv::Decoration::Restrict, i);
1042 if (!hasAliased && !hasRestrict)
1043 return emitOpError() <<
"with physical buffer pointer must be decorated "
1044 "either 'Aliased' or 'Restrict'";
1050 LogicalResult spirv::FuncOp::verifyBody() {
1051 FunctionType fnType = getFunctionType();
1054 if (
auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
1055 if (fnType.getNumResults() != 0)
1056 return retOp.emitOpError(
"cannot be used in functions returning value");
1057 }
else if (
auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
1058 if (fnType.getNumResults() != 1)
1059 return retOp.emitOpError(
1060 "returns 1 value but enclosing function requires ")
1061 << fnType.getNumResults() <<
" results";
1063 auto retOperandType = retOp.getValue().getType();
1064 auto fnResultType = fnType.getResult(0);
1065 if (retOperandType != fnResultType)
1066 return retOp.emitOpError(
" return value's type (")
1067 << retOperandType <<
") mismatch with function's result type ("
1068 << fnResultType <<
")";
1075 return failure(walkResult.wasInterrupted());
1079 StringRef name, FunctionType type,
1080 spirv::FunctionControl control,
1084 state.addAttribute(getFunctionTypeAttrName(state.name),
TypeAttr::get(type));
1085 state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
1086 builder.
getAttr<spirv::FunctionControlAttr>(control));
1087 state.attributes.append(attrs.begin(), attrs.end());
1135 Type type, StringRef name,
1136 unsigned descriptorSet,
unsigned binding) {
1139 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
1142 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
1147 Type type, StringRef name,
1148 spirv::BuiltIn builtin) {
1151 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
1158 StringAttr nameAttr;
1159 StringRef initializerAttrName =
1160 spirv::GlobalVariableOp::getInitializerAttrName(result.
name);
1181 StringRef typeAttrName =
1182 spirv::GlobalVariableOp::getTypeAttrName(result.
name);
1187 if (!llvm::isa<spirv::PointerType>(type)) {
1188 return parser.
emitError(loc,
"expected spirv.ptr type");
1197 spirv::attributeName<spirv::StorageClass>()};
1204 StringRef initializerAttrName = this->getInitializerAttrName();
1206 if (
auto initializer = this->getInitializer()) {
1207 printer <<
" " << initializerAttrName <<
'(';
1210 elidedAttrs.push_back(initializerAttrName);
1213 StringRef typeAttrName = this->getTypeAttrName();
1214 elidedAttrs.push_back(typeAttrName);
1216 printer <<
" : " <<
getType();
1220 if (!llvm::isa<spirv::PointerType>(
getType()))
1221 return emitOpError(
"result must be of a !spv.ptr type");
1227 auto storageClass = this->storageClass();
1228 if (storageClass == spirv::StorageClass::Generic ||
1229 storageClass == spirv::StorageClass::Function) {
1230 return emitOpError(
"storage class cannot be '")
1231 << stringifyStorageClass(storageClass) <<
"'";
1235 this->getInitializerAttrName())) {
1237 (*this)->getParentOp(), init.getAttr());
1241 if (!initOp || !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp,
1242 spirv::SpecConstantCompositeOp>(initOp)) {
1243 return emitOpError(
"initializer must be result of a "
1244 "spirv.SpecConstant or spirv.GlobalVariable or "
1245 "spirv.SpecConstantCompositeOp op");
1259 spirv::StorageClass storageClass;
1268 if (
auto valVecTy = llvm::dyn_cast<VectorType>(elementType))
1280 printer <<
" " << getPtr() <<
" : " <<
getType();
1297 spirv::StorageClass storageClass;
1308 if (
auto valVecTy = llvm::dyn_cast<VectorType>(elementType))
1319 printer <<
" " << getPtr() <<
", " << getValue() <<
" : "
1320 << getValue().getType();
1411 std::optional<StringRef> name) {
1421 spirv::AddressingModel addressingModel,
1422 spirv::MemoryModel memoryModel,
1423 std::optional<VerCapExtAttr> vceTriple,
1424 std::optional<StringRef> name) {
1427 builder.
getAttr<spirv::AddressingModelAttr>(addressingModel));
1428 state.addAttribute(
"memory_model",
1429 builder.
getAttr<spirv::MemoryModelAttr>(memoryModel));
1433 state.addAttribute(getVCETripleAttrName(), *vceTriple);
1444 StringAttr nameAttr;
1449 spirv::AddressingModel addrModel;
1450 spirv::MemoryModel memoryModel;
1451 if (spirv::parseEnumKeywordAttr<spirv::AddressingModelAttr>(addrModel, parser,
1453 spirv::parseEnumKeywordAttr<spirv::MemoryModelAttr>(memoryModel, parser,
1460 spirv::ModuleOp::getVCETripleAttrName(),
1477 if (std::optional<StringRef> name = getName()) {
1486 auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
1487 auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
1488 elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
1491 if (std::optional<spirv::VerCapExtAttr> triple = getVceTriple()) {
1492 printer <<
" requires " << *triple;
1493 elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
1501 LogicalResult spirv::ModuleOp::verifyRegions() {
1502 Dialect *dialect = (*this)->getDialect();
1507 for (
auto &op : *getBody()) {
1509 return op.
emitError(
"'spirv.module' can only contain spirv.* ops");
1514 if (
auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
1515 auto funcOp =
table.lookup<spirv::FuncOp>(entryPointOp.getFn());
1517 return entryPointOp.emitError(
"function '")
1518 << entryPointOp.getFn() <<
"' not found in 'spirv.module'";
1520 if (
auto interface = entryPointOp.getInterface()) {
1522 auto varSymRef = llvm::dyn_cast<FlatSymbolRefAttr>(varRef);
1524 return entryPointOp.emitError(
1525 "expected symbol reference for interface "
1526 "specification instead of '")
1530 table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
1532 return entryPointOp.emitError(
"expected spirv.GlobalVariable "
1533 "symbol reference instead of'")
1534 << varSymRef <<
"'";
1539 auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
1540 funcOp, entryPointOp.getExecutionModel());
1541 auto entryPtIt = entryPoints.find(key);
1542 if (entryPtIt != entryPoints.end()) {
1543 return entryPointOp.emitError(
"duplicate of a previous EntryPointOp");
1545 entryPoints[key] = entryPointOp;
1546 }
else if (
auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
1550 auto linkageAttr = funcOp.getLinkageAttributes();
1551 auto hasImportLinkage =
1552 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
1553 spirv::LinkageType::Import);
1554 if (funcOp.isExternal() && !hasImportLinkage)
1556 "'spirv.module' cannot contain external functions "
1557 "without 'Import' linkage_attributes (LinkageAttributes)");
1560 for (
auto &block : funcOp)
1561 for (
auto &op : block) {
1564 "functions in 'spirv.module' can only contain spirv.* ops");
1578 (*this)->getParentOp(), getSpecConstAttr());
1581 auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
1583 constType = specConstOp.getDefaultValue().getType();
1585 auto specConstCompositeOp =
1586 dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
1587 if (specConstCompositeOp)
1588 constType = specConstCompositeOp.getType();
1590 if (!specConstOp && !specConstCompositeOp)
1592 "expected spirv.SpecConstant or spirv.SpecConstantComposite symbol");
1594 if (getReference().
getType() != constType)
1595 return emitOpError(
"result type mismatch with the referenced "
1596 "specialization constant's type");
1607 StringAttr nameAttr;
1609 StringRef defaultValueAttrName =
1610 spirv::SpecConstantOp::getDefaultValueAttrName(result.
name);
1618 IntegerAttr specIdAttr;
1635 if (
auto specID = (*this)->getAttrOfType<IntegerAttr>(
kSpecIdAttrName))
1637 printer <<
" = " << getDefaultValue();
1641 if (
auto specID = (*this)->getAttrOfType<IntegerAttr>(
kSpecIdAttrName))
1642 if (specID.getValue().isNegative())
1643 return emitOpError(
"SpecId cannot be negative");
1645 auto value = getDefaultValue();
1646 if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
1648 if (!llvm::isa<spirv::SPIRVType>(value.getType()))
1649 return emitOpError(
"default value bitwidth disallowed");
1653 "default value can only be a bool, integer, or float scalar");
1661 VectorType resultType = llvm::cast<VectorType>(
getType());
1663 size_t numResultElements = resultType.getNumElements();
1664 if (numResultElements != getComponents().size())
1665 return emitOpError(
"result type element count (")
1666 << numResultElements
1667 <<
") mismatch with the number of component selectors ("
1668 << getComponents().size() <<
")";
1670 size_t totalSrcElements =
1671 llvm::cast<VectorType>(getVector1().
getType()).getNumElements() +
1672 llvm::cast<VectorType>(getVector2().
getType()).getNumElements();
1674 for (
const auto &selector : getComponents().getAsValueRange<IntegerAttr>()) {
1675 uint32_t index = selector.getZExtValue();
1676 if (index >= totalSrcElements &&
1677 index != std::numeric_limits<uint32_t>().
max())
1678 return emitOpError(
"component selector ")
1679 << index <<
" out of range: expected to be in [0, "
1680 << totalSrcElements <<
") or 0xffffffff";
1693 [](
auto matrixType) {
return matrixType.getElementType(); })
1694 .Default([](
Type) {
return nullptr; });
1696 assert(elementType &&
"Unhandled type");
1699 if (getScalar().
getType() != elementType)
1700 return emitOpError(
"input matrix components' type and scaling value must "
1701 "have the same type");
1711 auto inputMatrix = llvm::cast<spirv::MatrixType>(getMatrix().
getType());
1712 auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().
getType());
1715 if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
1716 return emitError(
"input matrix rows count must be equal to "
1717 "output matrix columns count");
1719 if (inputMatrix.getNumColumns() != resultMatrix.getNumRows())
1720 return emitError(
"input matrix columns count must be equal to "
1721 "output matrix rows count");
1724 if (inputMatrix.getElementType() != resultMatrix.getElementType())
1725 return emitError(
"input and output matrices must have the same "
1736 auto leftMatrix = llvm::cast<spirv::MatrixType>(getLeftmatrix().
getType());
1737 auto rightMatrix = llvm::cast<spirv::MatrixType>(getRightmatrix().
getType());
1738 auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().
getType());
1741 if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
1742 return emitError(
"left matrix columns' count must be equal to "
1743 "the right matrix rows' count");
1746 if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns())
1748 "right and result matrices must have equal columns' count");
1751 if (rightMatrix.getElementType() != resultMatrix.getElementType())
1752 return emitError(
"right and result matrices' component type must"
1756 if (leftMatrix.getElementType() != resultMatrix.getElementType())
1757 return emitError(
"left and result matrices' component type"
1758 " must be the same");
1761 if (leftMatrix.getNumRows() != resultMatrix.getNumRows())
1762 return emitError(
"left and result matrices must have equal rows' count");
1774 StringAttr compositeName;
1786 const char *attrName =
"spec_const";
1793 constituents.push_back(specConstRef);
1799 StringAttr compositeSpecConstituentsName =
1800 spirv::SpecConstantCompositeOp::getConstituentsAttrName(result.
name);
1808 StringAttr typeAttrName =
1809 spirv::SpecConstantCompositeOp::getTypeAttrName(result.
name);
1819 auto constituents = this->getConstituents().getValue();
1821 if (!constituents.empty())
1822 llvm::interleaveComma(constituents, printer);
1824 printer <<
") : " <<
getType();
1828 auto cType = llvm::dyn_cast<spirv::CompositeType>(
getType());
1829 auto constituents = this->getConstituents().getValue();
1832 return emitError(
"result type must be a composite type, but provided ")
1835 if (llvm::isa<spirv::CooperativeMatrixType>(cType))
1836 return emitError(
"unsupported composite type ") << cType;
1837 if (constituents.size() != cType.getNumElements())
1838 return emitError(
"has incorrect number of operands: expected ")
1839 << cType.getNumElements() <<
", but provided "
1840 << constituents.size();
1842 for (
auto index : llvm::seq<uint32_t>(0, constituents.size())) {
1843 auto constituent = llvm::cast<FlatSymbolRefAttr>(constituents[index]);
1845 auto constituentSpecConstOp =
1847 (*this)->getParentOp(), constituent.getAttr()));
1849 if (constituentSpecConstOp.getDefaultValue().getType() !=
1850 cType.getElementType(index))
1851 return emitError(
"has incorrect types of operands: expected ")
1852 << cType.getElementType(index) <<
", but provided "
1853 << constituentSpecConstOp.getDefaultValue().getType();
1891 printer <<
" wraps ";
1895 LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
1896 Block &block = getRegion().getBlocks().
front();
1899 return emitOpError(
"expected exactly 2 nested ops");
1904 return emitOpError(
"invalid enclosed op");
1907 if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
1908 spirv::SpecConstantOperationOp>(operand.getDefiningOp()))
1910 "invalid operand, must be defined by a constant operation");
1921 llvm::dyn_cast<spirv::StructType>(getResult().
getType());
1924 return emitError(
"result type must be a struct type with two memebers");
1928 VectorType exponentVecTy = llvm::dyn_cast<VectorType>(exponentTy);
1929 IntegerType exponentIntTy = llvm::dyn_cast<IntegerType>(exponentTy);
1931 Type operandTy = getOperand().getType();
1932 VectorType operandVecTy = llvm::dyn_cast<VectorType>(operandTy);
1933 FloatType operandFTy = llvm::dyn_cast<FloatType>(operandTy);
1935 if (significandTy != operandTy)
1936 return emitError(
"member zero of the resulting struct type must be the "
1937 "same type as the operand");
1939 if (exponentVecTy) {
1940 IntegerType componentIntTy =
1941 llvm::dyn_cast<IntegerType>(exponentVecTy.getElementType());
1942 if (!componentIntTy || componentIntTy.getWidth() != 32)
1943 return emitError(
"member one of the resulting struct type must"
1944 "be a scalar or vector of 32 bit integer type");
1945 }
else if (!exponentIntTy || exponentIntTy.getWidth() != 32) {
1946 return emitError(
"member one of the resulting struct type "
1947 "must be a scalar or vector of 32 bit integer type");
1951 if (operandVecTy && exponentVecTy &&
1952 (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
1955 if (operandFTy && exponentIntTy)
1958 return emitError(
"member one of the resulting struct type must have the same "
1959 "number of components as the operand type");
1967 Type significandType = getX().getType();
1968 Type exponentType = getExp().getType();
1970 if (llvm::isa<FloatType>(significandType) !=
1971 llvm::isa<IntegerType>(exponentType))
1972 return emitOpError(
"operands must both be scalars or vectors");
1975 if (
auto vectorType = llvm::dyn_cast<VectorType>(type))
1976 return vectorType.getNumElements();
1981 return emitOpError(
"operands must have the same number of elements");
1991 VectorType resultType = llvm::cast<VectorType>(getResult().
getType());
1992 auto sampledImageType =
1993 llvm::cast<spirv::SampledImageType>(getSampledimage().
getType());
1995 llvm::cast<spirv::ImageType>(sampledImageType.getImageType());
1997 if (resultType.getNumElements() != 4)
1998 return emitOpError(
"result type must be a vector of four components");
2000 Type elementType = resultType.getElementType();
2001 Type sampledElementType = imageType.getElementType();
2002 if (!llvm::isa<NoneType>(sampledElementType) &&
2003 elementType != sampledElementType)
2005 "the component type of result must be the same as sampled type of the "
2006 "underlying image type");
2008 spirv::Dim imageDim = imageType.getDim();
2009 spirv::ImageSamplingInfo imageMS = imageType.getSamplingInfo();
2011 if (imageDim != spirv::Dim::Dim2D && imageDim != spirv::Dim::Cube &&
2012 imageDim != spirv::Dim::Rect)
2014 "the Dim operand of the underlying image type must be 2D, Cube, or "
2017 if (imageMS != spirv::ImageSamplingInfo::SingleSampled)
2018 return emitOpError(
"the MS operand of the underlying image type must be 0");
2020 spirv::ImageOperandsAttr attr = getImageoperandsAttr();
2021 auto operandArguments = getOperandArguments();
2056 llvm::cast<spirv::ImageType>(getImage().
getType());
2057 Type resultType = getResult().getType();
2059 spirv::Dim dim = imageType.
getDim();
2063 case spirv::Dim::Dim1D:
2064 case spirv::Dim::Dim2D:
2065 case spirv::Dim::Dim3D:
2066 case spirv::Dim::Cube:
2067 if (samplingInfo != spirv::ImageSamplingInfo::MultiSampled &&
2068 samplerInfo != spirv::ImageSamplerUseInfo::SamplerUnknown &&
2069 samplerInfo != spirv::ImageSamplerUseInfo::NoSampler)
2071 "if Dim is 1D, 2D, 3D, or Cube, "
2072 "it must also have either an MS of 1 or a Sampled of 0 or 2");
2074 case spirv::Dim::Buffer:
2075 case spirv::Dim::Rect:
2078 return emitError(
"the Dim operand of the image type must "
2079 "be 1D, 2D, 3D, Buffer, Cube, or Rect");
2082 unsigned componentNumber = 0;
2084 case spirv::Dim::Dim1D:
2085 case spirv::Dim::Buffer:
2086 componentNumber = 1;
2088 case spirv::Dim::Dim2D:
2089 case spirv::Dim::Cube:
2090 case spirv::Dim::Rect:
2091 componentNumber = 2;
2093 case spirv::Dim::Dim3D:
2094 componentNumber = 3;
2100 if (imageType.
getArrayedInfo() == spirv::ImageArrayedInfo::Arrayed)
2101 componentNumber += 1;
2103 unsigned resultComponentNumber = 1;
2104 if (
auto resultVectorType = llvm::dyn_cast<VectorType>(resultType))
2105 resultComponentNumber = resultVectorType.getNumElements();
2107 if (componentNumber != resultComponentNumber)
2108 return emitError(
"expected the result to have ")
2109 << componentNumber <<
" component(s), but found "
2110 << resultComponentNumber <<
" component(s)";
2121 return emitOpError(
"vector operand and result type mismatch");
2122 auto scalarType = llvm::cast<VectorType>(
getType()).getElementType();
2123 if (getScalar().
getType() != scalarType)
2124 return emitOpError(
"scalar operand and result element type match");
static std::string bindingName()
Returns the string name of the Binding decoration.
static std::string descriptorSetName()
Returns the string name of the DescriptorSet decoration.
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser, OperationState &result)
static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, Type opType)
static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, OperationState &result)
static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op)
static LogicalResult verifyShiftOp(Operation *op)
static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op, Value ptr, Value val)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static void printOneResultOp(Operation *op, OpAsmPrinter &p)
static LogicalResult verifyImageOperands(Op imageOp, spirv::ImageOperandsAttr attr, Operation::operand_range operands)
static void printArithmeticExtendedBinaryOp(Operation *op, OpAsmPrinter &printer)
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseOptionalSymbolName(StringAttr &result)=0
Parse an optional @-identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
ParseResult addTypesToList(ArrayRef< Type > types, SmallVectorImpl< Type > &result)
Add the specified types to the end of the specified type list and return success.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
OpListType & getOperations()
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
FloatAttr getFloatAttr(Type type, double value)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
StringAttr getStringAttr(const Twine &bytes)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual Operation * parseGenericOperation(Block *insertBlock, Block::iterator insertPt)=0
Parse an operation in its generic form.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printGenericOp(Operation *op, bool printOpName=true)=0
Print the entire operation with the default generic assembly form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
A trait to mark ops that can be enclosed/wrapped in a SpecConstantOperation op.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
type_range getType() const
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
AttrClass getAttrOfType(StringAttr name)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
void push_back(Block *block)
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Dialect & getDialect() const
Get the dialect this type is registered to.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
ImageArrayedInfo getArrayedInfo() const
ImageSamplerUseInfo getSamplerUseInfo() const
ImageSamplingInfo getSamplingInfo() const
static PointerType get(Type pointeeType, StorageClass storageClass)
unsigned getNumElements() const
Type getElementType(unsigned) const
An attribute that specifies the SPIR-V (version, capabilities, extensions) triple.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the argument at 'index'.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
ParseResult parseFunctionSignature(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function attributes prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
constexpr char kFnNameAttrName[]
constexpr char kSpecIdAttrName[]
LogicalResult verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics)
ParseResult parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next string attribute in parser as an enumerant of the given EnumClass.
void printVariableDecorations(Operation *op, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs)
AddressingModel getAddressingModel(TargetEnvAttr targetAttr, bool use64bitAddress)
Returns addressing model selected based on target environment.
FailureOr< ExecutionModel > getExecutionModel(TargetEnvAttr targetAttr)
Returns execution model selected based on target environment.
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
LogicalResult extractValueFromConstOp(Operation *op, int32_t &value)
ParseResult parseVariableDecorations(OpAsmParser &parser, OperationState &state)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.