33 #include "llvm/ADT/APFloat.h"
34 #include "llvm/ADT/APInt.h"
35 #include "llvm/ADT/ArrayRef.h"
36 #include "llvm/ADT/STLExtras.h"
37 #include "llvm/ADT/StringExtras.h"
38 #include "llvm/ADT/TypeSwitch.h"
42 #include <type_traits>
52 auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
56 auto valueAttr = constOp.getValue();
57 auto integerValueAttr = llvm::dyn_cast<IntegerAttr>(valueAttr);
58 if (!integerValueAttr) {
62 if (integerValueAttr.getType().isSignlessInteger())
63 value = integerValueAttr.getInt();
65 value = integerValueAttr.getSInt();
72 spirv::MemorySemantics memorySemantics) {
79 auto atMostOneInSet = spirv::MemorySemantics::Acquire |
80 spirv::MemorySemantics::Release |
81 spirv::MemorySemantics::AcquireRelease |
82 spirv::MemorySemantics::SequentiallyConsistent;
85 llvm::popcount(
static_cast<uint32_t
>(memorySemantics & atMostOneInSet));
88 "expected at most one of these four memory constraints "
89 "to be set: `Acquire`, `Release`,"
90 "`AcquireRelease` or `SequentiallyConsistent`");
99 stringifyDecoration(spirv::Decoration::DescriptorSet));
100 auto bindingName = llvm::convertToSnakeFromCamelCase(
101 stringifyDecoration(spirv::Decoration::Binding));
104 if (descriptorSet && binding) {
107 printer <<
" bind(" << descriptorSet.getInt() <<
", " << binding.getInt()
112 auto builtInName = llvm::convertToSnakeFromCamelCase(
113 stringifyDecoration(spirv::Decoration::BuiltIn));
114 if (
auto builtin = op->
getAttrOfType<StringAttr>(builtInName)) {
115 printer <<
" " << builtInName <<
"(\"" << builtin.getValue() <<
"\")";
116 elidedAttrs.push_back(builtInName);
134 auto fnType = llvm::dyn_cast<FunctionType>(type);
136 parser.
emitError(loc,
"expected function type");
141 result.
addTypes(fnType.getResults());
152 assert(op->
getNumResults() == 1 &&
"op should have one result");
158 [&](
Type type) { return type != resultType; })) {
167 p <<
" : " << resultType;
170 template <
typename Op>
172 spirv::ImageOperandsAttr attr,
175 if (operands.empty())
178 return imageOp.
emitError(
"the Image Operands should encode what operands "
179 "follow, as per Image Operands");
183 spirv::ImageOperands noSupportOperands =
184 spirv::ImageOperands::Bias | spirv::ImageOperands::Lod |
185 spirv::ImageOperands::Grad | spirv::ImageOperands::ConstOffset |
186 spirv::ImageOperands::Offset | spirv::ImageOperands::ConstOffsets |
187 spirv::ImageOperands::Sample | spirv::ImageOperands::MinLod |
188 spirv::ImageOperands::MakeTexelAvailable |
189 spirv::ImageOperands::MakeTexelVisible |
190 spirv::ImageOperands::SignExtend | spirv::ImageOperands::ZeroExtend;
192 if (spirv::bitEnumContainsAll(attr.getValue(), noSupportOperands))
193 llvm_unreachable(
"unimplemented operands of Image Operands");
198 template <
typename BlockReadWriteOpTy>
202 if (
auto valVecTy = llvm::dyn_cast<VectorType>(valType))
203 valType = valVecTy.getElementType();
206 llvm::cast<spirv::PointerType>(ptr.
getType()).getPointeeType()) {
207 return op.
emitOpError(
"mismatch in result type and pointer type");
218 if (indices.empty()) {
219 emitErrorFn(
"expected at least one index for spirv.CompositeExtract");
223 for (
auto index : indices) {
224 if (
auto cType = llvm::dyn_cast<spirv::CompositeType>(type)) {
225 if (cType.hasCompileTimeKnownNumElements() &&
227 static_cast<uint64_t
>(index) >= cType.getNumElements())) {
228 emitErrorFn(
"index ") << index <<
" out of bounds for " << type;
231 type = cType.getElementType(index);
233 emitErrorFn(
"cannot extract from non-composite type ")
234 << type <<
" with index " << index;
244 auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(indices);
245 if (!indicesArrayAttr) {
246 emitErrorFn(
"expected a 32-bit integer array attribute for 'indices'");
249 if (indicesArrayAttr.empty()) {
250 emitErrorFn(
"expected at least one index for spirv.CompositeExtract");
255 for (
auto indexAttr : indicesArrayAttr) {
256 auto indexIntAttr = llvm::dyn_cast<IntegerAttr>(indexAttr);
258 emitErrorFn(
"expected an 32-bit integer for index, but found '")
262 indexVals.push_back(indexIntAttr.getInt());
282 template <
typename ExtendedBinaryOp>
284 auto resultType = llvm::cast<spirv::StructType>(op.getType());
285 if (resultType.getNumElements() != 2)
286 return op.
emitOpError(
"expected result struct type containing two members");
288 if (!llvm::all_equal({op.getOperand1().getType(), op.getOperand2().getType(),
289 resultType.getElementType(0),
290 resultType.getElementType(1)}))
292 "expected all operand types and struct member types are the same");
309 auto structType = llvm::dyn_cast<spirv::StructType>(resultType);
310 if (!structType || structType.getNumElements() != 2)
311 return parser.
emitError(loc,
"expected spirv.struct type with two members");
331 return op->
emitError(
"expected the same type for the first operand and "
332 "result, but provided ")
344 spirv::GlobalVariableOp var) {
349 auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
353 return emitOpError(
"expected spirv.GlobalVariable symbol");
355 if (getPointer().getType() != varOp.getType()) {
357 "result type mismatch with the referenced global variable's type");
367 operand_range constituents = this->getConstituents();
375 auto coopElementType =
378 [](
auto coopType) {
return coopType.getElementType(); })
379 .Default([](
Type) {
return nullptr; });
382 if (coopElementType) {
383 if (constituents.size() != 1)
384 return emitOpError(
"has incorrect number of operands: expected ")
385 <<
"1, but provided " << constituents.size();
386 if (coopElementType != constituents.front().getType())
387 return emitOpError(
"operand type mismatch: expected operand type ")
388 << coopElementType <<
", but provided "
389 << constituents.front().getType();
394 auto cType = llvm::cast<spirv::CompositeType>(getType());
395 if (constituents.size() == cType.getNumElements()) {
396 for (
auto index : llvm::seq<uint32_t>(0, constituents.size())) {
397 if (constituents[index].getType() != cType.getElementType(index)) {
398 return emitOpError(
"operand type mismatch: expected operand type ")
399 << cType.getElementType(index) <<
", but provided "
400 << constituents[index].getType();
407 auto resultType = llvm::dyn_cast<VectorType>(cType);
410 "expected to return a vector or cooperative matrix when the number of "
411 "constituents is less than what the result needs");
414 for (
Value component : constituents) {
415 if (!llvm::isa<VectorType>(component.getType()) &&
416 !component.getType().isIntOrFloat())
417 return emitOpError(
"operand type mismatch: expected operand to have "
418 "a scalar or vector type, but provided ")
419 << component.getType();
421 Type elementType = component.getType();
422 if (
auto vectorType = llvm::dyn_cast<VectorType>(component.getType())) {
423 sizes.push_back(vectorType.getNumElements());
424 elementType = vectorType.getElementType();
429 if (elementType != resultType.getElementType())
430 return emitOpError(
"operand element type mismatch: expected to be ")
431 << resultType.getElementType() <<
", but provided " << elementType;
433 unsigned totalCount = std::accumulate(sizes.begin(), sizes.end(), 0);
434 if (totalCount != cType.getNumElements())
435 return emitOpError(
"has incorrect number of operands: expected ")
436 << cType.getNumElements() <<
", but provided " << totalCount;
453 build(builder, state, elementType, composite, indexAttr);
460 StringRef indicesAttrName =
461 spirv::CompositeExtractOp::getIndicesAttrName(result.
name);
483 printer <<
' ' << getComposite() <<
getIndices() <<
" : "
488 auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(
getIndices());
490 getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
494 if (resultType != getType()) {
495 return emitOpError(
"invalid result type: expected ")
496 << resultType <<
" but provided " << getType();
510 build(builder, state, composite.
getType(),
object, composite, indexAttr);
516 Type objectType, compositeType;
518 StringRef indicesAttrName =
519 spirv::CompositeInsertOp::getIndicesAttrName(result.
name);
533 auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(
getIndices());
535 getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
539 if (objectType != getObject().getType()) {
540 return emitOpError(
"object operand type should be ")
541 << objectType <<
", but found " << getObject().getType();
544 if (getComposite().getType() != getType()) {
545 return emitOpError(
"result type should be the same as "
546 "the composite type, but found ")
547 << getComposite().getType() <<
" vs " << getType();
554 printer <<
" " << getObject() <<
", " << getComposite() <<
getIndices()
555 <<
" : " << getObject().
getType() <<
" into "
556 << getComposite().getType();
566 StringRef valueAttrName = spirv::ConstantOp::getValueAttrName(result.
name);
571 if (
auto typedAttr = llvm::dyn_cast<TypedAttr>(value))
572 type = typedAttr.getType();
573 if (llvm::isa<NoneType, TensorType>(type)) {
582 printer <<
' ' << getValue();
583 if (llvm::isa<spirv::ArrayType>(getType()))
584 printer <<
" : " << getType();
589 if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
590 auto valueType = llvm::cast<TypedAttr>(value).getType();
591 if (valueType != opType)
593 << opType <<
") does not match value type (" << valueType <<
")";
596 if (llvm::isa<DenseIntOrFPElementsAttr, SparseElementsAttr>(value)) {
597 auto valueType = llvm::cast<TypedAttr>(value).getType();
598 if (valueType == opType)
600 auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
601 auto shapedType = llvm::dyn_cast<ShapedType>(valueType);
604 << opType <<
") does not match value type (" << valueType
605 <<
"), must be the same or spirv.array";
607 int numElements = arrayType.getNumElements();
608 auto opElemType = arrayType.getElementType();
609 while (
auto t = llvm::dyn_cast<spirv::ArrayType>(opElemType)) {
610 numElements *= t.getNumElements();
611 opElemType = t.getElementType();
613 if (!opElemType.isIntOrFloat())
614 return op.
emitOpError(
"only support nested array result type");
616 auto valueElemType = shapedType.getElementType();
617 if (valueElemType != opElemType) {
619 << opElemType <<
") does not match value element type ("
620 << valueElemType <<
")";
623 if (numElements != shapedType.getNumElements()) {
624 return op.
emitOpError(
"result number of elements (")
625 << numElements <<
") does not match value number of elements ("
626 << shapedType.getNumElements() <<
")";
630 if (
auto arrayAttr = llvm::dyn_cast<ArrayAttr>(value)) {
631 auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
634 "must have spirv.array result type for array value");
635 Type elemType = arrayType.getElementType();
636 for (
Attribute element : arrayAttr.getValue()) {
643 return op.
emitOpError(
"cannot have attribute: ") << value;
653 bool spirv::ConstantOp::isBuildableWith(
Type type) {
655 if (!llvm::isa<spirv::SPIRVType>(type))
660 return llvm::isa<spirv::ArrayType>(type);
668 if (
auto intType = llvm::dyn_cast<IntegerType>(type)) {
669 unsigned width = intType.getWidth();
671 return builder.
create<spirv::ConstantOp>(loc, type,
673 return builder.
create<spirv::ConstantOp>(
676 if (
auto floatType = llvm::dyn_cast<FloatType>(type)) {
677 return builder.
create<spirv::ConstantOp>(
680 if (
auto vectorType = llvm::dyn_cast<VectorType>(type)) {
681 Type elemType = vectorType.getElementType();
682 if (llvm::isa<IntegerType>(elemType)) {
683 return builder.
create<spirv::ConstantOp>(
688 if (llvm::isa<FloatType>(elemType)) {
689 return builder.
create<spirv::ConstantOp>(
696 llvm_unreachable(
"unimplemented types for ConstantOp::getZero()");
699 spirv::ConstantOp spirv::ConstantOp::getOne(
Type type,
Location loc,
701 if (
auto intType = llvm::dyn_cast<IntegerType>(type)) {
702 unsigned width = intType.getWidth();
704 return builder.
create<spirv::ConstantOp>(loc, type,
706 return builder.
create<spirv::ConstantOp>(
709 if (
auto floatType = llvm::dyn_cast<FloatType>(type)) {
710 return builder.
create<spirv::ConstantOp>(
713 if (
auto vectorType = llvm::dyn_cast<VectorType>(type)) {
714 Type elemType = vectorType.getElementType();
715 if (llvm::isa<IntegerType>(elemType)) {
716 return builder.
create<spirv::ConstantOp>(
721 if (llvm::isa<FloatType>(elemType)) {
722 return builder.
create<spirv::ConstantOp>(
729 llvm_unreachable(
"unimplemented types for ConstantOp::getOne()");
732 void mlir::spirv::ConstantOp::getAsmResultNames(
734 Type type = getType();
737 llvm::raw_svector_ostream specialName(specialNameBuffer);
738 specialName <<
"cst";
740 IntegerType intTy = llvm::dyn_cast<IntegerType>(type);
742 if (IntegerAttr intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
743 if (intTy && intTy.getWidth() == 1) {
744 return setNameFn(getResult(), (intCst.getInt() ?
"true" :
"false"));
747 if (intTy.isSignless()) {
748 specialName << intCst.getInt();
749 }
else if (intTy.isUnsigned()) {
750 specialName << intCst.getUInt();
752 specialName << intCst.getSInt();
756 if (intTy || llvm::isa<FloatType>(type)) {
757 specialName <<
'_' << type;
760 if (
auto vecType = llvm::dyn_cast<VectorType>(type)) {
761 specialName <<
"_vec_";
762 specialName << vecType.getDimSize(0);
764 Type elementType = vecType.getElementType();
766 if (llvm::isa<IntegerType>(elementType) ||
767 llvm::isa<FloatType>(elementType)) {
768 specialName <<
"x" << elementType;
772 setNameFn(getResult(), specialName.str());
775 void mlir::spirv::AddressOfOp::getAsmResultNames(
778 llvm::raw_svector_ostream specialName(specialNameBuffer);
779 specialName << getVariable() <<
"_addr";
780 setNameFn(getResult(), specialName.str());
796 spirv::ExecutionModel executionModel,
797 spirv::FuncOp
function,
799 build(builder, state,
806 spirv::ExecutionModel execModel;
812 if (parseEnumStrAttr<spirv::ExecutionModelAttr>(execModel, parser, result) ||
821 FlatSymbolRefAttr var;
823 if (parser.parseAttribute(var, Type(),
"var_symbol", attrs))
825 interfaceVars.push_back(var);
838 auto interfaceVars = getInterface().getValue();
839 if (!interfaceVars.empty()) {
841 llvm::interleaveComma(interfaceVars, printer);
856 spirv::FuncOp
function,
857 spirv::ExecutionMode executionMode,
866 spirv::ExecutionMode execMode;
869 parseEnumStrAttr<spirv::ExecutionModeAttr>(execMode, parser, result)) {
881 values.push_back(llvm::cast<IntegerAttr>(value).getInt());
883 StringRef valuesAttrName =
884 spirv::ExecutionModeOp::getValuesAttrName(result.
name);
893 printer <<
" \"" << stringifyExecutionMode(getExecutionMode()) <<
"\"";
894 auto values = this->getValues();
898 llvm::interleaveComma(values, printer, [&](
Attribute a) {
899 printer << llvm::cast<IntegerAttr>(a).getInt();
920 bool isVariadic =
false;
922 parser,
false, entryArgs, isVariadic, resultTypes,
927 for (
auto &arg : entryArgs)
928 argTypes.push_back(arg.type);
934 spirv::FunctionControl fnControl;
935 if (parseEnumStrAttr<spirv::FunctionControlAttr>(fnControl, parser, result))
943 assert(resultAttrs.size() == resultTypes.size());
945 builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.
name),
946 getResAttrsAttrName(result.
name));
959 auto fnType = getFunctionType();
961 printer, *
this, fnType.getInputs(),
962 false, fnType.getResults());
963 printer <<
" \"" << spirv::stringifyFunctionControl(getFunctionControl())
967 {spirv::attributeName<spirv::FunctionControl>(),
968 getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
969 getFunctionControlAttrName()});
972 Region &body = this->getBody();
981 FunctionType fnType = getFunctionType();
982 if (fnType.getNumResults() > 1)
983 return emitOpError(
"cannot have more than one result");
985 auto hasDecorationAttr = [&](spirv::Decoration decoration,
987 auto func = llvm::cast<FunctionOpInterface>(getOperation());
988 for (
auto argAttr : cast<FunctionOpInterface>(func).
getArgAttrs(argIndex)) {
989 if (argAttr.getName() != spirv::DecorationAttr::name)
991 if (
auto decAttr = dyn_cast<spirv::DecorationAttr>(argAttr.getValue()))
992 return decAttr.getValue() == decoration;
997 for (
unsigned i = 0, e = this->getNumArguments(); i != e; ++i) {
998 Type param = fnType.getInputs()[i];
999 auto inputPtrType = dyn_cast<spirv::PointerType>(param);
1003 auto pointeePtrType =
1004 dyn_cast<spirv::PointerType>(inputPtrType.getPointeeType());
1005 if (pointeePtrType) {
1011 if (pointeePtrType.getStorageClass() !=
1012 spirv::StorageClass::PhysicalStorageBuffer)
1015 bool hasAliasedPtr =
1016 hasDecorationAttr(spirv::Decoration::AliasedPointer, i);
1017 bool hasRestrictPtr =
1018 hasDecorationAttr(spirv::Decoration::RestrictPointer, i);
1019 if (!hasAliasedPtr && !hasRestrictPtr)
1020 return emitOpError()
1021 <<
"with a pointer points to a physical buffer pointer must "
1022 "be decorated either 'AliasedPointer' or 'RestrictPointer'";
1029 if (
auto pointeeArrayType =
1030 dyn_cast<spirv::ArrayType>(inputPtrType.getPointeeType())) {
1032 dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
1034 pointeePtrType = inputPtrType;
1037 if (!pointeePtrType || pointeePtrType.getStorageClass() !=
1038 spirv::StorageClass::PhysicalStorageBuffer)
1041 bool hasAliased = hasDecorationAttr(spirv::Decoration::Aliased, i);
1042 bool hasRestrict = hasDecorationAttr(spirv::Decoration::Restrict, i);
1043 if (!hasAliased && !hasRestrict)
1044 return emitOpError() <<
"with physical buffer pointer must be decorated "
1045 "either 'Aliased' or 'Restrict'";
1052 FunctionType fnType = getFunctionType();
1055 if (
auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
1056 if (fnType.getNumResults() != 0)
1057 return retOp.emitOpError(
"cannot be used in functions returning value");
1058 }
else if (
auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
1059 if (fnType.getNumResults() != 1)
1060 return retOp.emitOpError(
1061 "returns 1 value but enclosing function requires ")
1062 << fnType.getNumResults() <<
" results";
1064 auto retOperandType = retOp.getValue().getType();
1065 auto fnResultType = fnType.getResult(0);
1066 if (retOperandType != fnResultType)
1067 return retOp.emitOpError(
" return value's type (")
1068 << retOperandType <<
") mismatch with function's result type ("
1069 << fnResultType <<
")";
1076 return failure(walkResult.wasInterrupted());
1080 StringRef name, FunctionType type,
1081 spirv::FunctionControl control,
1085 state.addAttribute(getFunctionTypeAttrName(state.name),
TypeAttr::get(type));
1086 state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
1087 builder.
getAttr<spirv::FunctionControlAttr>(control));
1088 state.attributes.append(attrs.begin(), attrs.end());
1136 Type type, StringRef name,
1137 unsigned descriptorSet,
unsigned binding) {
1140 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
1143 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
1148 Type type, StringRef name,
1149 spirv::BuiltIn builtin) {
1152 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
1159 StringAttr nameAttr;
1160 StringRef initializerAttrName =
1161 spirv::GlobalVariableOp::getInitializerAttrName(result.
name);
1182 StringRef typeAttrName =
1183 spirv::GlobalVariableOp::getTypeAttrName(result.
name);
1188 if (!llvm::isa<spirv::PointerType>(type)) {
1189 return parser.
emitError(loc,
"expected spirv.ptr type");
1198 spirv::attributeName<spirv::StorageClass>()};
1205 StringRef initializerAttrName = this->getInitializerAttrName();
1207 if (
auto initializer = this->getInitializer()) {
1208 printer <<
" " << initializerAttrName <<
'(';
1211 elidedAttrs.push_back(initializerAttrName);
1214 StringRef typeAttrName = this->getTypeAttrName();
1215 elidedAttrs.push_back(typeAttrName);
1217 printer <<
" : " << getType();
1221 if (!llvm::isa<spirv::PointerType>(getType()))
1222 return emitOpError(
"result must be of a !spv.ptr type");
1228 auto storageClass = this->storageClass();
1229 if (storageClass == spirv::StorageClass::Generic ||
1230 storageClass == spirv::StorageClass::Function) {
1231 return emitOpError(
"storage class cannot be '")
1232 << stringifyStorageClass(storageClass) <<
"'";
1236 this->getInitializerAttrName())) {
1238 (*this)->getParentOp(), init.getAttr());
1242 if (!initOp || !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp,
1243 spirv::SpecConstantCompositeOp>(initOp)) {
1244 return emitOpError(
"initializer must be result of a "
1245 "spirv.SpecConstant or spirv.GlobalVariable or "
1246 "spirv.SpecConstantCompositeOp op");
1260 spirv::StorageClass storageClass;
1269 if (
auto valVecTy = llvm::dyn_cast<VectorType>(elementType))
1281 printer <<
" " << getPtr() <<
" : " << getType();
1298 spirv::StorageClass storageClass;
1309 if (
auto valVecTy = llvm::dyn_cast<VectorType>(elementType))
1320 printer <<
" " << getPtr() <<
", " << getValue() <<
" : "
1321 << getValue().getType();
1412 std::optional<StringRef> name) {
1422 spirv::AddressingModel addressingModel,
1423 spirv::MemoryModel memoryModel,
1424 std::optional<VerCapExtAttr> vceTriple,
1425 std::optional<StringRef> name) {
1428 builder.
getAttr<spirv::AddressingModelAttr>(addressingModel));
1429 state.addAttribute(
"memory_model",
1430 builder.
getAttr<spirv::MemoryModelAttr>(memoryModel));
1434 state.addAttribute(getVCETripleAttrName(), *vceTriple);
1445 StringAttr nameAttr;
1450 spirv::AddressingModel addrModel;
1451 spirv::MemoryModel memoryModel;
1452 if (spirv::parseEnumKeywordAttr<spirv::AddressingModelAttr>(addrModel, parser,
1454 spirv::parseEnumKeywordAttr<spirv::MemoryModelAttr>(memoryModel, parser,
1461 spirv::ModuleOp::getVCETripleAttrName(),
1478 if (std::optional<StringRef> name = getName()) {
1487 auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
1488 auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
1489 elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
1492 if (std::optional<spirv::VerCapExtAttr> triple = getVceTriple()) {
1493 printer <<
" requires " << *triple;
1494 elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
1503 Dialect *dialect = (*this)->getDialect();
1508 for (
auto &op : *getBody()) {
1510 return op.
emitError(
"'spirv.module' can only contain spirv.* ops");
1515 if (
auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
1516 auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.getFn());
1518 return entryPointOp.emitError(
"function '")
1519 << entryPointOp.getFn() <<
"' not found in 'spirv.module'";
1521 if (
auto interface = entryPointOp.getInterface()) {
1523 auto varSymRef = llvm::dyn_cast<FlatSymbolRefAttr>(varRef);
1525 return entryPointOp.emitError(
1526 "expected symbol reference for interface "
1527 "specification instead of '")
1531 table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
1533 return entryPointOp.emitError(
"expected spirv.GlobalVariable "
1534 "symbol reference instead of'")
1535 << varSymRef <<
"'";
1540 auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
1541 funcOp, entryPointOp.getExecutionModel());
1542 auto entryPtIt = entryPoints.find(key);
1543 if (entryPtIt != entryPoints.end()) {
1544 return entryPointOp.emitError(
"duplicate of a previous EntryPointOp");
1546 entryPoints[key] = entryPointOp;
1547 }
else if (
auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
1551 auto linkageAttr = funcOp.getLinkageAttributes();
1552 auto hasImportLinkage =
1553 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
1554 spirv::LinkageType::Import);
1555 if (funcOp.isExternal() && !hasImportLinkage)
1557 "'spirv.module' cannot contain external functions "
1558 "without 'Import' linkage_attributes (LinkageAttributes)");
1561 for (
auto &block : funcOp)
1562 for (
auto &op : block) {
1565 "functions in 'spirv.module' can only contain spirv.* ops");
1579 (*this)->getParentOp(), getSpecConstAttr());
1582 auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
1584 constType = specConstOp.getDefaultValue().getType();
1586 auto specConstCompositeOp =
1587 dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
1588 if (specConstCompositeOp)
1589 constType = specConstCompositeOp.getType();
1591 if (!specConstOp && !specConstCompositeOp)
1593 "expected spirv.SpecConstant or spirv.SpecConstantComposite symbol");
1595 if (getReference().getType() != constType)
1596 return emitOpError(
"result type mismatch with the referenced "
1597 "specialization constant's type");
1608 StringAttr nameAttr;
1610 StringRef defaultValueAttrName =
1611 spirv::SpecConstantOp::getDefaultValueAttrName(result.
name);
1619 IntegerAttr specIdAttr;
1636 if (
auto specID = (*this)->getAttrOfType<IntegerAttr>(
kSpecIdAttrName))
1638 printer <<
" = " << getDefaultValue();
1642 if (
auto specID = (*this)->getAttrOfType<IntegerAttr>(
kSpecIdAttrName))
1643 if (specID.getValue().isNegative())
1644 return emitOpError(
"SpecId cannot be negative");
1646 auto value = getDefaultValue();
1647 if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
1649 if (!llvm::isa<spirv::SPIRVType>(value.getType()))
1650 return emitOpError(
"default value bitwidth disallowed");
1654 "default value can only be a bool, integer, or float scalar");
1662 VectorType resultType = llvm::cast<VectorType>(getType());
1664 size_t numResultElements = resultType.getNumElements();
1665 if (numResultElements != getComponents().size())
1666 return emitOpError(
"result type element count (")
1667 << numResultElements
1668 <<
") mismatch with the number of component selectors ("
1669 << getComponents().size() <<
")";
1671 size_t totalSrcElements =
1672 llvm::cast<VectorType>(getVector1().getType()).getNumElements() +
1673 llvm::cast<VectorType>(getVector2().getType()).getNumElements();
1675 for (
const auto &selector : getComponents().getAsValueRange<IntegerAttr>()) {
1676 uint32_t index = selector.getZExtValue();
1677 if (index >= totalSrcElements &&
1678 index != std::numeric_limits<uint32_t>().
max())
1679 return emitOpError(
"component selector ")
1680 << index <<
" out of range: expected to be in [0, "
1681 << totalSrcElements <<
") or 0xffffffff";
1694 [](
auto matrixType) {
return matrixType.getElementType(); })
1695 .Default([](
Type) {
return nullptr; });
1697 assert(elementType &&
"Unhandled type");
1700 if (getScalar().getType() != elementType)
1701 return emitOpError(
"input matrix components' type and scaling value must "
1702 "have the same type");
1712 auto inputMatrix = llvm::cast<spirv::MatrixType>(getMatrix().getType());
1713 auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().getType());
1716 if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
1717 return emitError(
"input matrix rows count must be equal to "
1718 "output matrix columns count");
1720 if (inputMatrix.getNumColumns() != resultMatrix.getNumRows())
1721 return emitError(
"input matrix columns count must be equal to "
1722 "output matrix rows count");
1725 if (inputMatrix.getElementType() != resultMatrix.getElementType())
1726 return emitError(
"input and output matrices must have the same "
1737 auto leftMatrix = llvm::cast<spirv::MatrixType>(getLeftmatrix().getType());
1738 auto rightMatrix = llvm::cast<spirv::MatrixType>(getRightmatrix().getType());
1739 auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().getType());
1742 if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
1743 return emitError(
"left matrix columns' count must be equal to "
1744 "the right matrix rows' count");
1747 if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns())
1749 "right and result matrices must have equal columns' count");
1752 if (rightMatrix.getElementType() != resultMatrix.getElementType())
1753 return emitError(
"right and result matrices' component type must"
1757 if (leftMatrix.getElementType() != resultMatrix.getElementType())
1758 return emitError(
"left and result matrices' component type"
1759 " must be the same");
1762 if (leftMatrix.getNumRows() != resultMatrix.getNumRows())
1763 return emitError(
"left and result matrices must have equal rows' count");
1775 StringAttr compositeName;
1787 const char *attrName =
"spec_const";
1794 constituents.push_back(specConstRef);
1800 StringAttr compositeSpecConstituentsName =
1801 spirv::SpecConstantCompositeOp::getConstituentsAttrName(result.
name);
1809 StringAttr typeAttrName =
1810 spirv::SpecConstantCompositeOp::getTypeAttrName(result.
name);
1820 auto constituents = this->getConstituents().getValue();
1822 if (!constituents.empty())
1823 llvm::interleaveComma(constituents, printer);
1825 printer <<
") : " << getType();
1829 auto cType = llvm::dyn_cast<spirv::CompositeType>(getType());
1830 auto constituents = this->getConstituents().getValue();
1833 return emitError(
"result type must be a composite type, but provided ")
1836 if (llvm::isa<spirv::CooperativeMatrixType>(cType))
1837 return emitError(
"unsupported composite type ") << cType;
1838 if (llvm::isa<spirv::JointMatrixINTELType>(cType))
1839 return emitError(
"unsupported composite type ") << cType;
1840 if (constituents.size() != cType.getNumElements())
1841 return emitError(
"has incorrect number of operands: expected ")
1842 << cType.getNumElements() <<
", but provided "
1843 << constituents.size();
1845 for (
auto index : llvm::seq<uint32_t>(0, constituents.size())) {
1846 auto constituent = llvm::cast<FlatSymbolRefAttr>(constituents[index]);
1848 auto constituentSpecConstOp =
1850 (*this)->getParentOp(), constituent.getAttr()));
1852 if (constituentSpecConstOp.getDefaultValue().getType() !=
1853 cType.getElementType(index))
1854 return emitError(
"has incorrect types of operands: expected ")
1855 << cType.getElementType(index) <<
", but provided "
1856 << constituentSpecConstOp.getDefaultValue().getType();
1894 printer <<
" wraps ";
1898 LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
1899 Block &block = getRegion().getBlocks().
front();
1902 return emitOpError(
"expected exactly 2 nested ops");
1907 return emitOpError(
"invalid enclosed op");
1910 if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
1911 spirv::SpecConstantOperationOp>(operand.getDefiningOp()))
1913 "invalid operand, must be defined by a constant operation");
1924 llvm::dyn_cast<spirv::StructType>(getResult().getType());
1927 return emitError(
"result type must be a struct type with two memebers");
1931 VectorType exponentVecTy = llvm::dyn_cast<VectorType>(exponentTy);
1932 IntegerType exponentIntTy = llvm::dyn_cast<IntegerType>(exponentTy);
1934 Type operandTy = getOperand().getType();
1935 VectorType operandVecTy = llvm::dyn_cast<VectorType>(operandTy);
1936 FloatType operandFTy = llvm::dyn_cast<FloatType>(operandTy);
1938 if (significandTy != operandTy)
1939 return emitError(
"member zero of the resulting struct type must be the "
1940 "same type as the operand");
1942 if (exponentVecTy) {
1943 IntegerType componentIntTy =
1944 llvm::dyn_cast<IntegerType>(exponentVecTy.getElementType());
1945 if (!componentIntTy || componentIntTy.getWidth() != 32)
1946 return emitError(
"member one of the resulting struct type must"
1947 "be a scalar or vector of 32 bit integer type");
1948 }
else if (!exponentIntTy || exponentIntTy.getWidth() != 32) {
1949 return emitError(
"member one of the resulting struct type "
1950 "must be a scalar or vector of 32 bit integer type");
1954 if (operandVecTy && exponentVecTy &&
1955 (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
1958 if (operandFTy && exponentIntTy)
1961 return emitError(
"member one of the resulting struct type must have the same "
1962 "number of components as the operand type");
1970 Type significandType = getX().getType();
1971 Type exponentType = getExp().getType();
1973 if (llvm::isa<FloatType>(significandType) !=
1974 llvm::isa<IntegerType>(exponentType))
1975 return emitOpError(
"operands must both be scalars or vectors");
1978 if (
auto vectorType = llvm::dyn_cast<VectorType>(type))
1979 return vectorType.getNumElements();
1984 return emitOpError(
"operands must have the same number of elements");
1994 VectorType resultType = llvm::cast<VectorType>(getResult().getType());
1995 auto sampledImageType =
1996 llvm::cast<spirv::SampledImageType>(getSampledimage().getType());
1998 llvm::cast<spirv::ImageType>(sampledImageType.getImageType());
2000 if (resultType.getNumElements() != 4)
2001 return emitOpError(
"result type must be a vector of four components");
2003 Type elementType = resultType.getElementType();
2004 Type sampledElementType = imageType.getElementType();
2005 if (!llvm::isa<NoneType>(sampledElementType) &&
2006 elementType != sampledElementType)
2008 "the component type of result must be the same as sampled type of the "
2009 "underlying image type");
2011 spirv::Dim imageDim = imageType.getDim();
2012 spirv::ImageSamplingInfo imageMS = imageType.getSamplingInfo();
2014 if (imageDim != spirv::Dim::Dim2D && imageDim != spirv::Dim::Cube &&
2015 imageDim != spirv::Dim::Rect)
2017 "the Dim operand of the underlying image type must be 2D, Cube, or "
2020 if (imageMS != spirv::ImageSamplingInfo::SingleSampled)
2021 return emitOpError(
"the MS operand of the underlying image type must be 0");
2023 spirv::ImageOperandsAttr attr = getImageoperandsAttr();
2024 auto operandArguments = getOperandArguments();
2059 llvm::cast<spirv::ImageType>(getImage().getType());
2060 Type resultType = getResult().getType();
2062 spirv::Dim dim = imageType.
getDim();
2066 case spirv::Dim::Dim1D:
2067 case spirv::Dim::Dim2D:
2068 case spirv::Dim::Dim3D:
2069 case spirv::Dim::Cube:
2070 if (samplingInfo != spirv::ImageSamplingInfo::MultiSampled &&
2071 samplerInfo != spirv::ImageSamplerUseInfo::SamplerUnknown &&
2072 samplerInfo != spirv::ImageSamplerUseInfo::NoSampler)
2074 "if Dim is 1D, 2D, 3D, or Cube, "
2075 "it must also have either an MS of 1 or a Sampled of 0 or 2");
2077 case spirv::Dim::Buffer:
2078 case spirv::Dim::Rect:
2081 return emitError(
"the Dim operand of the image type must "
2082 "be 1D, 2D, 3D, Buffer, Cube, or Rect");
2085 unsigned componentNumber = 0;
2087 case spirv::Dim::Dim1D:
2088 case spirv::Dim::Buffer:
2089 componentNumber = 1;
2091 case spirv::Dim::Dim2D:
2092 case spirv::Dim::Cube:
2093 case spirv::Dim::Rect:
2094 componentNumber = 2;
2096 case spirv::Dim::Dim3D:
2097 componentNumber = 3;
2103 if (imageType.
getArrayedInfo() == spirv::ImageArrayedInfo::Arrayed)
2104 componentNumber += 1;
2106 unsigned resultComponentNumber = 1;
2107 if (
auto resultVectorType = llvm::dyn_cast<VectorType>(resultType))
2108 resultComponentNumber = resultVectorType.getNumElements();
2110 if (componentNumber != resultComponentNumber)
2111 return emitError(
"expected the result to have ")
2112 << componentNumber <<
" component(s), but found "
2113 << resultComponentNumber <<
" component(s)";
2123 if (getVector().getType() != getType())
2124 return emitOpError(
"vector operand and result type mismatch");
2125 auto scalarType = llvm::cast<VectorType>(getType()).getElementType();
2126 if (getScalar().getType() != scalarType)
2127 return emitOpError(
"scalar operand and result element type match");
static std::string bindingName()
Returns the string name of the Binding decoration.
static std::string descriptorSetName()
Returns the string name of the DescriptorSet decoration.
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser, OperationState &result)
static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, Type opType)
static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, OperationState &result)
static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op)
static LogicalResult verifyShiftOp(Operation *op)
static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op, Value ptr, Value val)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static void printOneResultOp(Operation *op, OpAsmPrinter &p)
static LogicalResult verifyImageOperands(Op imageOp, spirv::ImageOperandsAttr attr, Operation::operand_range operands)
static void printArithmeticExtendedBinaryOp(Operation *op, OpAsmPrinter &printer)
static int64_t getNumElements(ShapedType type)
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseOptionalSymbolName(StringAttr &result)=0
Parse an optional -identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
ParseResult addTypesToList(ArrayRef< Type > types, SmallVectorImpl< Type > &result)
Add the specified types to the end of the specified type list and return success.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
OpListType & getOperations()
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
FloatAttr getFloatAttr(Type type, double value)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
StringAttr getStringAttr(const Twine &bytes)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual Operation * parseGenericOperation(Block *insertBlock, Block::iterator insertPt)=0
Parse an operation in its generic form.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printGenericOp(Operation *op, bool printOpName=true)=0
Print the entire operation with the default generic assembly form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
A trait to mark ops that can be enclosed/wrapped in a SpecConstantOperation op.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
type_range getType() const
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
AttrClass getAttrOfType(StringAttr name)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class represents success/failure for parsing-like operations that find it important to chain tog...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
void push_back(Block *block)
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Dialect & getDialect() const
Get the dialect this type is registered to.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
ImageArrayedInfo getArrayedInfo() const
ImageSamplerUseInfo getSamplerUseInfo() const
ImageSamplingInfo getSamplingInfo() const
static PointerType get(Type pointeeType, StorageClass storageClass)
unsigned getNumElements() const
Type getElementType(unsigned) const
An attribute that specifies the SPIR-V (version, capabilities, extensions) triple.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the argument at 'index'.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
ParseResult parseFunctionSignature(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
constexpr char kFnNameAttrName[]
constexpr char kSpecIdAttrName[]
LogicalResult verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics)
ParseResult parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next string attribute in parser as an enumerant of the given EnumClass.
void printVariableDecorations(Operation *op, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs)
AddressingModel getAddressingModel(TargetEnvAttr targetAttr, bool use64bitAddress)
Returns addressing model selected based on target environment.
FailureOr< ExecutionModel > getExecutionModel(TargetEnvAttr targetAttr)
Returns execution model selected based on target environment.
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
LogicalResult extractValueFromConstOp(Operation *op, int32_t &value)
ParseResult parseVariableDecorations(OpAsmParser &parser, OperationState &state)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.