30 #include "llvm/ADT/APFloat.h"
31 #include "llvm/ADT/APInt.h"
32 #include "llvm/ADT/ArrayRef.h"
33 #include "llvm/ADT/STLExtras.h"
34 #include "llvm/ADT/StringExtras.h"
35 #include "llvm/Support/FormatVariadic.h"
84 auto fnType = type.
dyn_cast<FunctionType>();
86 parser.
emitError(loc,
"expected function type");
91 result.
addTypes(fnType.getResults());
102 assert(op->
getNumResults() == 1 &&
"op should have one result");
108 [&](
Type type) { return type != resultType; })) {
117 p <<
" : " << resultType;
127 if (isa<FunctionOpInterface>(op))
139 auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
143 auto valueAttr = constOp.getValue();
144 auto integerValueAttr = valueAttr.dyn_cast<IntegerAttr>();
145 if (!integerValueAttr) {
149 if (integerValueAttr.getType().isSignlessInteger())
150 value = integerValueAttr.getInt();
152 value = integerValueAttr.getSInt();
157 template <
typename Ty>
161 if (enumValues.empty()) {
165 enumValStrs.reserve(enumValues.size());
166 for (
auto val : enumValues) {
167 enumValStrs.emplace_back(stringifyFn(val));
174 template <
typename EnumClass>
177 StringRef attrName = spirv::attributeName<EnumClass>()) {
184 if (!attrVal.
isa<StringAttr>())
185 return parser.
emitError(loc,
"expected ")
186 << attrName <<
" attribute specified as string";
188 spirv::symbolizeEnum<EnumClass>(attrVal.
cast<StringAttr>().getValue());
191 << attrName <<
" attribute specification: " << attrVal;
192 value = *attrOptional;
199 template <
typename EnumAttrClass,
200 typename EnumClass =
typename EnumAttrClass::ValueType>
203 StringRef attrName = spirv::attributeName<EnumClass>()) {
214 template <
typename EnumAttrClass,
215 typename EnumClass =
typename EnumAttrClass::ValueType>
219 StringRef attrName = spirv::attributeName<EnumClass>()) {
229 template <
typename EnumAttrClass,
typename EnumClass>
232 StringRef attrName = spirv::attributeName<EnumClass>()) {
236 parseEnumKeywordAttr<EnumAttrClass>(control, parser, state) ||
244 builder.
getAttr<EnumAttrClass>(
static_cast<EnumClass
>(0)));
262 spirv::MemoryAccess memoryAccessAttr;
263 if (parseEnumStrAttr<spirv::MemoryAccessAttr>(memoryAccessAttr, parser, state,
267 if (spirv::bitEnumContainsAll(memoryAccessAttr,
268 spirv::MemoryAccess::Aligned)) {
293 spirv::MemoryAccess memoryAccessAttr;
294 if (parseEnumStrAttr<spirv::MemoryAccessAttr>(memoryAccessAttr, parser, state,
298 if (spirv::bitEnumContainsAll(memoryAccessAttr,
299 spirv::MemoryAccess::Aligned)) {
312 template <
typename MemoryOpTy>
316 std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt,
317 std::optional<uint32_t> alignmentAttrValue = std::nullopt) {
319 if (
auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
320 : memoryOp.getMemoryAccess())) {
323 printer <<
" [\"" << stringifyMemoryAccess(*memAccess) <<
"\"";
325 if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
327 if (
auto alignment = (alignmentAttrValue ? alignmentAttrValue
328 : memoryOp.getAlignment())) {
330 printer <<
", " << *alignment;
335 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
342 template <
typename MemoryOpTy>
346 std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt,
347 std::optional<uint32_t> alignmentAttrValue = std::nullopt) {
352 if (
auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
353 : memoryOp.getMemoryAccess())) {
356 printer <<
" [\"" << stringifyMemoryAccess(*memAccess) <<
"\"";
358 if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
360 if (
auto alignment = (alignmentAttrValue ? alignmentAttrValue
361 : memoryOp.getAlignment())) {
363 printer <<
", " << *alignment;
368 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
372 spirv::ImageOperandsAttr &attr) {
377 spirv::ImageOperands imageOperands;
381 attr = spirv::ImageOperandsAttr::get(parser.
getContext(), imageOperands);
387 spirv::ImageOperandsAttr attr) {
389 auto strImageOperands = stringifyImageOperands(attr.getValue());
390 printer <<
"[\"" << strImageOperands <<
"\"]";
394 template <
typename Op>
396 spirv::ImageOperandsAttr attr,
399 if (operands.empty())
402 return imageOp.
emitError(
"the Image Operands should encode what operands "
403 "follow, as per Image Operands");
407 spirv::ImageOperands noSupportOperands =
408 spirv::ImageOperands::Bias | spirv::ImageOperands::Lod |
409 spirv::ImageOperands::Grad | spirv::ImageOperands::ConstOffset |
410 spirv::ImageOperands::Offset | spirv::ImageOperands::ConstOffsets |
411 spirv::ImageOperands::Sample | spirv::ImageOperands::MinLod |
412 spirv::ImageOperands::MakeTexelAvailable |
413 spirv::ImageOperands::MakeTexelVisible |
414 spirv::ImageOperands::SignExtend | spirv::ImageOperands::ZeroExtend;
416 if (spirv::bitEnumContainsAll(attr.getValue(), noSupportOperands))
417 llvm_unreachable(
"unimplemented operands of Image Operands");
423 bool requireSameBitWidth =
true,
424 bool skipBitWidthCheck =
false) {
426 if (skipBitWidthCheck)
433 if (
auto vectorType = operandType.
dyn_cast<VectorType>()) {
434 operandType = vectorType.getElementType();
438 if (
auto coopMatrixType =
440 operandType = coopMatrixType.getElementType();
445 if (
auto jointMatrixType =
447 operandType = jointMatrixType.getElementType();
454 auto isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth;
456 if (requireSameBitWidth) {
457 if (!isSameBitWidth) {
459 "expected the same bit widths for operand type and result "
460 "type, but provided ")
461 << operandType <<
" and " << resultType;
466 if (isSameBitWidth) {
468 "expected the different bit widths for operand type and result "
469 "type, but provided ")
470 << operandType <<
" and " << resultType;
475 template <
typename MemoryOpTy>
480 auto *op = memoryOp.getOperation();
482 if (!memAccessAttr) {
486 return memoryOp.emitOpError(
487 "invalid alignment specification without aligned memory access "
493 auto memAccess = memAccessAttr.template cast<spirv::MemoryAccessAttr>();
496 return memoryOp.emitOpError(
"invalid memory access specifier: ")
500 if (spirv::bitEnumContainsAll(memAccess.getValue(),
501 spirv::MemoryAccess::Aligned)) {
503 return memoryOp.emitOpError(
"missing alignment value");
507 return memoryOp.emitOpError(
508 "invalid alignment specification with non-aligned memory access "
519 template <
typename MemoryOpTy>
524 auto *op = memoryOp.getOperation();
526 if (!memAccessAttr) {
530 return memoryOp.emitOpError(
531 "invalid alignment specification without aligned memory access "
537 auto memAccess = memAccessAttr.template cast<spirv::MemoryAccessAttr>();
540 return memoryOp.emitOpError(
"invalid memory access specifier: ")
544 if (spirv::bitEnumContainsAll(memAccess.getValue(),
545 spirv::MemoryAccess::Aligned)) {
547 return memoryOp.emitOpError(
"missing alignment value");
551 return memoryOp.emitOpError(
552 "invalid alignment specification with non-aligned memory access "
567 auto atMostOneInSet = spirv::MemorySemantics::Acquire |
568 spirv::MemorySemantics::Release |
569 spirv::MemorySemantics::AcquireRelease |
570 spirv::MemorySemantics::SequentiallyConsistent;
573 llvm::popcount(
static_cast<uint32_t
>(memorySemantics & atMostOneInSet));
576 "expected at most one of these four memory constraints "
577 "to be set: `Acquire`, `Release`,"
578 "`AcquireRelease` or `SequentiallyConsistent`");
583 template <
typename LoadStoreOpTy>
593 return op.emitOpError(
"mismatch in result type and pointer type");
598 template <
typename BlockReadWriteOpTy>
602 if (
auto valVecTy = valType.dyn_cast<VectorType>())
603 valType = valVecTy.getElementType();
606 return op.emitOpError(
"mismatch in result type and pointer type");
613 auto builtInName = llvm::convertToSnakeFromCamelCase(
614 stringifyDecoration(spirv::Decoration::BuiltIn));
619 stringifyDecoration(spirv::Decoration::DescriptorSet));
620 auto bindingName = llvm::convertToSnakeFromCamelCase(
621 stringifyDecoration(spirv::Decoration::Binding));
652 stringifyDecoration(spirv::Decoration::DescriptorSet));
653 auto bindingName = llvm::convertToSnakeFromCamelCase(
654 stringifyDecoration(spirv::Decoration::Binding));
657 if (descriptorSet && binding) {
660 printer <<
" bind(" << descriptorSet.getInt() <<
", " << binding.getInt()
665 auto builtInName = llvm::convertToSnakeFromCamelCase(
666 stringifyDecoration(spirv::Decoration::BuiltIn));
667 if (
auto builtin = op->
getAttrOfType<StringAttr>(builtInName)) {
668 printer <<
" " << builtInName <<
"(\"" << builtin.getValue() <<
"\")";
669 elidedAttrs.push_back(builtInName);
686 if (
auto vectorType = type.
dyn_cast<VectorType>()) {
687 assert(vectorType.getElementType().isIntOrFloat());
688 return vectorType.getNumElements() *
689 vectorType.getElementType().getIntOrFloatBitWidth();
691 llvm_unreachable(
"unhandled bit width computation for type");
700 if (indices.empty()) {
701 emitErrorFn(
"expected at least one index for spirv.CompositeExtract");
705 for (
auto index : indices) {
707 if (cType.hasCompileTimeKnownNumElements() &&
709 static_cast<uint64_t
>(index) >= cType.getNumElements())) {
710 emitErrorFn(
"index ") << index <<
" out of bounds for " << type;
713 type = cType.getElementType(index);
715 emitErrorFn(
"cannot extract from non-composite type ")
716 << type <<
" with index " << index;
726 auto indicesArrayAttr = indices.
dyn_cast<ArrayAttr>();
727 if (!indicesArrayAttr) {
728 emitErrorFn(
"expected a 32-bit integer array attribute for 'indices'");
731 if (indicesArrayAttr.empty()) {
732 emitErrorFn(
"expected at least one index for spirv.CompositeExtract");
737 for (
auto indexAttr : indicesArrayAttr) {
738 auto indexIntAttr = indexAttr.dyn_cast<IntegerAttr>();
740 emitErrorFn(
"expected an 32-bit integer for index, but found '")
744 indexVals.push_back(indexIntAttr.getInt());
766 return !block.
empty() && std::next(block.
begin()) == block.
end() &&
767 isa<spirv::MergeOp>(block.
front());
770 template <
typename ExtendedBinaryOp>
772 auto resultType = op.getType().template cast<spirv::StructType>();
773 if (resultType.getNumElements() != 2)
774 return op.emitOpError(
"expected result struct type containing two members");
776 if (!llvm::all_equal({op.getOperand1().getType(), op.getOperand2().getType(),
777 resultType.getElementType(0),
778 resultType.getElementType(1)}))
779 return op.emitOpError(
780 "expected all operand types and struct member types are the same");
798 if (!structType || structType.getNumElements() != 2)
799 return parser.
emitError(loc,
"expected spirv.struct type with two members");
826 spirv::MemorySemantics memoryScope;
831 if (parseEnumStrAttr<spirv::ScopeAttr>(scope, parser, state,
833 parseEnumStrAttr<spirv::MemorySemanticsAttr>(memoryScope, parser, state,
841 return parser.
emitError(loc,
"expected pointer type");
844 operandTypes.push_back(ptrType);
846 operandTypes.push_back(ptrType.getPointeeType());
857 printer << spirv::stringifyScope(scopeAttr.getValue()) <<
"\" \"";
858 auto memorySemanticsAttr =
860 printer << spirv::stringifyMemorySemantics(memorySemanticsAttr.getValue())
864 template <
typename T>
878 template <
typename ExpectedElementType>
882 if (!elementType.isa<ExpectedElementType>())
883 return op->
emitOpError() <<
"pointer operand must point to an "
884 << stringifyTypeName<ExpectedElementType>()
885 <<
" value, found " << elementType;
889 if (valueType != elementType)
890 return op->
emitOpError(
"expected value to have the same type as the "
891 "pointer operand's pointee type ")
892 << elementType <<
", but found " << valueType;
894 auto memorySemantics =
905 spirv::Scope executionScope;
906 spirv::GroupOperation groupOperation;
908 if (parseEnumStrAttr<spirv::ScopeAttr>(executionScope, parser, state,
910 parseEnumStrAttr<spirv::GroupOperationAttr>(groupOperation, parser, state,
915 std::optional<OpAsmParser::UnresolvedOperand> clusterSizeInfo;
930 if (clusterSizeInfo) {
947 << stringifyGroupOperation(groupOp
948 ->getAttrOfType<spirv::GroupOperationAttr>(
962 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
964 "execution scope must be 'Workgroup' or 'Subgroup'");
966 spirv::GroupOperation operation =
969 if (operation == spirv::GroupOperation::ClusteredReduce &&
971 return groupOp->
emitOpError(
"cluster size operand must be provided for "
972 "'ClusteredReduce' group operation");
975 int32_t clusterSize = 0;
980 "cluster size operand must come from a constant op");
982 if (!llvm::isPowerOf2_32(clusterSize))
984 "cluster size operand must be a power of two");
993 if (
auto vecType = operandType.
dyn_cast<VectorType>())
994 return VectorType::get(vecType.getNumElements(), resultType);
1000 return op->
emitError(
"expected the same type for the first operand and "
1001 "result, but provided ")
1015 emitError(baseLoc,
"'spirv.AccessChain' op expected a pointer "
1016 "to composite type, but provided ")
1021 auto resultType = ptrType.getPointeeType();
1022 auto resultStorageClass = ptrType.getStorageClass();
1025 for (
auto indexSSA : indices) {
1030 "'spirv.AccessChain' op cannot extract from non-composite type ")
1031 << resultType <<
" with index " << index;
1036 Operation *op = indexSSA.getDefiningOp();
1038 emitError(baseLoc,
"'spirv.AccessChain' op index must be an "
1039 "integer spirv.Constant to access "
1040 "element of spirv.struct");
1049 "'spirv.AccessChain' index must be an integer spirv.Constant to "
1050 "access element of spirv.struct, but provided ")
1054 if (index < 0 ||
static_cast<uint64_t
>(index) >= cType.getNumElements()) {
1055 emitError(baseLoc,
"'spirv.AccessChain' op index ")
1056 << index <<
" out of bounds for " << resultType;
1060 resultType = cType.getElementType(index);
1068 assert(type &&
"Unable to deduce return type based on basePtr and indices");
1069 build(builder, state, type, basePtr, indices);
1089 if (indicesInfo.empty()) {
1091 "'spirv.AccessChain' op expected at "
1092 "least one index ");
1100 if (indicesTypes.size() != indicesInfo.size()) {
1102 result.
location,
"'spirv.AccessChain' op indices types' count must be "
1103 "equal to indices info count");
1119 template <
typename Op>
1121 printer <<
' ' << op.getBasePtr() <<
'[' << indices
1122 <<
"] : " << op.getBasePtr().getType() <<
", " << indices.
getTypes();
1129 template <
typename Op>
1132 indices, accessChainOp.
getLoc());
1136 auto providedResultType =
1137 accessChainOp.getType().template dyn_cast<spirv::PointerType>();
1138 if (!providedResultType)
1140 "result type must be a pointer, but provided")
1141 << providedResultType;
1143 if (resultType != providedResultType)
1144 return accessChainOp.
emitOpError(
"invalid result type: expected ")
1145 << resultType <<
", but provided " << providedResultType;
1159 spirv::GlobalVariableOp var) {
1160 build(builder, state, var.getType(), SymbolRefAttr::get(var));
1164 auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
1166 getVariableAttr()));
1168 return emitOpError(
"expected spirv.GlobalVariable symbol");
1170 if (getPointer().getType() != varOp.getType()) {
1172 "result type mismatch with the referenced global variable's type");
1177 template <
typename T>
1179 printer <<
" \"" << stringifyScope(atomOp.getMemoryScope()) <<
"\" \""
1180 << stringifyMemorySemantics(atomOp.getEqualSemantics()) <<
"\" \""
1181 << stringifyMemorySemantics(atomOp.getUnequalSemantics()) <<
"\" "
1182 << atomOp.getOperands() <<
" : " << atomOp.getPointer().getType();
1187 spirv::Scope memoryScope;
1188 spirv::MemorySemantics equalSemantics, unequalSemantics;
1191 if (parseEnumStrAttr<spirv::ScopeAttr>(memoryScope, parser, state,
1193 parseEnumStrAttr<spirv::MemorySemanticsAttr>(
1195 parseEnumStrAttr<spirv::MemorySemanticsAttr>(
1206 return parser.
emitError(loc,
"expected pointer type");
1210 {ptrType, ptrType.getPointeeType(), ptrType.getPointeeType()},
1217 template <
typename T>
1223 if (atomOp.getType() != atomOp.getValue().getType())
1224 return atomOp.emitOpError(
"value operand must have the same type as the op "
1225 "result, but found ")
1226 << atomOp.getValue().getType() <<
" vs " << atomOp.getType();
1228 if (atomOp.getType() != atomOp.getComparator().getType())
1229 return atomOp.emitOpError(
1230 "comparator operand must have the same type as the op "
1231 "result, but found ")
1232 << atomOp.getComparator().getType() <<
" vs " << atomOp.getType();
1234 Type pointeeType = atomOp.getPointer()
1236 .template cast<spirv::PointerType>()
1238 if (atomOp.getType() != pointeeType)
1239 return atomOp.emitOpError(
1240 "pointer operand's pointee type must have the same "
1241 "as the op result type, but found ")
1242 << pointeeType <<
" vs " << atomOp.getType();
1255 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1303 printer <<
" \"" << stringifyScope(getMemoryScope()) <<
"\" \""
1304 << stringifyMemorySemantics(getSemantics()) <<
"\" " << getOperands()
1305 <<
" : " << getPointer().getType();
1310 spirv::Scope memoryScope;
1311 spirv::MemorySemantics semantics;
1314 if (parseEnumStrAttr<spirv::ScopeAttr>(memoryScope, parser, result,
1316 parseEnumStrAttr<spirv::MemorySemanticsAttr>(semantics, parser, result,
1327 return parser.
emitError(loc,
"expected pointer type");
1329 if (parser.
resolveOperands(operandInfo, {ptrType, ptrType.getPointeeType()},
1337 if (getType() != getValue().getType())
1338 return emitOpError(
"value operand must have the same type as the op "
1339 "result, but found ")
1340 << getValue().getType() <<
" vs " << getType();
1344 if (getType() != pointeeType)
1345 return emitOpError(
"pointer operand's pointee type must have the same "
1346 "as the op result type, but found ")
1347 << pointeeType <<
" vs " << getType();
1357 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1373 return ::verifyAtomicUpdateOp<FloatType>(getOperation());
1389 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1405 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1421 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1437 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1453 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1469 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1485 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1501 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1517 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1535 auto operandType = getOperand().getType();
1536 auto resultType = getResult().getType();
1537 if (operandType == resultType) {
1538 return emitError(
"result type must be different from operand type");
1543 "unhandled bit cast conversion from pointer type to non-pointer type");
1548 "unhandled bit cast conversion from non-pointer type to pointer type");
1552 if (operandBitWidth != resultBitWidth) {
1553 return emitOpError(
"mismatch in result type bitwidth ")
1554 << resultBitWidth <<
" and operand type bitwidth "
1569 if (operandStorage != spirv::StorageClass::Workgroup &&
1570 operandStorage != spirv::StorageClass::CrossWorkgroup &&
1571 operandStorage != spirv::StorageClass::Function)
1572 return emitError(
"pointer must point to the Workgroup, CrossWorkgroup"
1573 ", or Function Storage Class");
1576 if (resultStorage != spirv::StorageClass::Generic)
1577 return emitError(
"result type must be of storage class Generic");
1579 Type operandPointeeType = operandType.getPointeeType();
1581 if (operandPointeeType != resultPointeeType)
1582 return emitOpError(
"pointer operand's pointee type must have the same "
1583 "as the op result type, but found ")
1584 << operandPointeeType <<
" vs " << resultPointeeType;
1597 if (operandStorage != spirv::StorageClass::Generic)
1598 return emitError(
"pointer type must be of storage class Generic");
1601 if (resultStorage != spirv::StorageClass::Workgroup &&
1602 resultStorage != spirv::StorageClass::CrossWorkgroup &&
1603 resultStorage != spirv::StorageClass::Function)
1604 return emitError(
"result must point to the Workgroup, CrossWorkgroup, "
1605 "or Function Storage Class");
1607 Type operandPointeeType = operandType.getPointeeType();
1609 if (operandPointeeType != resultPointeeType)
1610 return emitOpError(
"pointer operand's pointee type must have the same "
1611 "as the op result type, but found ")
1612 << operandPointeeType <<
" vs " << resultPointeeType;
1625 if (operandStorage != spirv::StorageClass::Generic)
1626 return emitError(
"pointer type must be of storage class Generic");
1629 if (resultStorage != spirv::StorageClass::Workgroup &&
1630 resultStorage != spirv::StorageClass::CrossWorkgroup &&
1631 resultStorage != spirv::StorageClass::Function)
1632 return emitError(
"result must point to the Workgroup, CrossWorkgroup, "
1633 "or Function Storage Class");
1635 Type operandPointeeType = operandType.getPointeeType();
1637 if (operandPointeeType != resultPointeeType)
1638 return emitOpError(
"pointer operand's pointee type must have the same "
1639 "as the op result type, but found ")
1640 << operandPointeeType <<
" vs " << resultPointeeType;
1649 assert(index == 0 &&
"invalid successor index");
1658 spirv::BranchConditionalOp::getSuccessorOperands(
unsigned index) {
1659 assert(index < 2 &&
"invalid successor index");
1661 ? getTrueTargetOperandsMutable()
1662 : getFalseTargetOperandsMutable());
1679 IntegerAttr trueWeight, falseWeight;
1683 if (parser.
parseAttribute(trueWeight, i32Type,
"weight", weights) ||
1685 parser.
parseAttribute(falseWeight, i32Type,
"weight", weights) ||
1708 result.
addAttribute(spirv::BranchConditionalOp::getOperandSegmentSizeAttr(),
1710 {1, static_cast<int32_t>(trueOperands.size()),
1711 static_cast<int32_t>(falseOperands.size())}));
1717 printer <<
' ' << getCondition();
1719 if (
auto weights = getBranchWeights()) {
1721 llvm::interleaveComma(weights->getValue(), printer, [&](
Attribute a) {
1722 printer << a.cast<IntegerAttr>().getInt();
1734 if (
auto weights = getBranchWeights()) {
1735 if (weights->getValue().size() != 2) {
1736 return emitOpError(
"must have exactly two branch weights");
1738 if (llvm::all_of(*weights, [](
Attribute attr) {
1739 return attr.
cast<IntegerAttr>().getValue().
isZero();
1741 return emitOpError(
"branch weights cannot both be zero");
1753 operand_range constituents = this->getConstituents();
1756 if (constituents.size() != 1)
1757 return emitOpError(
"has incorrect number of operands: expected ")
1758 <<
"1, but provided " << constituents.size();
1759 if (coopType.getElementType() != constituents.front().getType())
1760 return emitOpError(
"operand type mismatch: expected operand type ")
1761 << coopType.getElementType() <<
", but provided "
1762 << constituents.front().getType();
1767 if (constituents.size() != 1)
1768 return emitOpError(
"has incorrect number of operands: expected ")
1769 <<
"1, but provided " << constituents.size();
1770 if (jointType.getElementType() != constituents.front().getType())
1771 return emitOpError(
"operand type mismatch: expected operand type ")
1772 << jointType.getElementType() <<
", but provided "
1773 << constituents.front().getType();
1777 if (constituents.size() == cType.getNumElements()) {
1778 for (
auto index : llvm::seq<uint32_t>(0, constituents.size())) {
1779 if (constituents[index].getType() != cType.getElementType(index)) {
1780 return emitOpError(
"operand type mismatch: expected operand type ")
1781 << cType.getElementType(index) <<
", but provided "
1782 << constituents[index].getType();
1790 auto resultType = cType.dyn_cast<VectorType>();
1793 "expected to return a vector or cooperative matrix when the number of "
1794 "constituents is less than what the result needs");
1797 for (
Value component : constituents) {
1798 if (!component.getType().isa<VectorType>() &&
1799 !component.getType().isIntOrFloat())
1800 return emitOpError(
"operand type mismatch: expected operand to have "
1801 "a scalar or vector type, but provided ")
1802 << component.getType();
1804 Type elementType = component.getType();
1805 if (
auto vectorType = component.getType().dyn_cast<VectorType>()) {
1806 sizes.push_back(vectorType.getNumElements());
1807 elementType = vectorType.getElementType();
1812 if (elementType != resultType.getElementType())
1813 return emitOpError(
"operand element type mismatch: expected to be ")
1814 << resultType.getElementType() <<
", but provided " << elementType;
1816 unsigned totalCount = std::accumulate(sizes.begin(), sizes.end(), 0);
1817 if (totalCount != cType.getNumElements())
1818 return emitOpError(
"has incorrect number of operands: expected ")
1819 << cType.getNumElements() <<
", but provided " << totalCount;
1836 build(builder, state, elementType, composite, indexAttr);
1855 getElementType(compositeType, indicesAttr, parser, attrLocation);
1864 printer <<
' ' << getComposite() <<
getIndices() <<
" : "
1869 auto indicesArrayAttr =
getIndices().dyn_cast<ArrayAttr>();
1871 getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
1875 if (resultType != getType()) {
1876 return emitOpError(
"invalid result type: expected ")
1877 << resultType <<
" but provided " << getType();
1891 build(builder, state, composite.
getType(),
object, composite, indexAttr);
1897 Type objectType, compositeType;
1912 auto indicesArrayAttr =
getIndices().dyn_cast<ArrayAttr>();
1914 getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
1918 if (objectType != getObject().getType()) {
1919 return emitOpError(
"object operand type should be ")
1920 << objectType <<
", but found " << getObject().getType();
1923 if (getComposite().getType() != getType()) {
1924 return emitOpError(
"result type should be the same as "
1925 "the composite type, but found ")
1926 << getComposite().getType() <<
" vs " << getType();
1933 printer <<
" " << getObject() <<
", " << getComposite() <<
getIndices()
1934 <<
" : " << getObject().
getType() <<
" into "
1935 << getComposite().getType();
1949 if (
auto typedAttr = value.
dyn_cast<TypedAttr>())
1950 type = typedAttr.getType();
1960 printer <<
' ' << getValue();
1961 if (getType().isa<spirv::ArrayType>())
1962 printer <<
" : " << getType();
1967 if (value.
isa<IntegerAttr, FloatAttr>()) {
1968 auto valueType = value.
cast<TypedAttr>().getType();
1969 if (valueType != opType)
1970 return op.emitOpError(
"result type (")
1971 << opType <<
") does not match value type (" << valueType <<
")";
1975 auto valueType = value.
cast<TypedAttr>().getType();
1976 if (valueType == opType)
1979 auto shapedType = valueType.dyn_cast<ShapedType>();
1981 return op.emitOpError(
"result or element type (")
1982 << opType <<
") does not match value type (" << valueType
1983 <<
"), must be the same or spirv.array";
1985 int numElements = arrayType.getNumElements();
1986 auto opElemType = arrayType.getElementType();
1988 numElements *= t.getNumElements();
1989 opElemType = t.getElementType();
1991 if (!opElemType.isIntOrFloat())
1992 return op.emitOpError(
"only support nested array result type");
1995 if (valueElemType != opElemType) {
1996 return op.emitOpError(
"result element type (")
1997 << opElemType <<
") does not match value element type ("
1998 << valueElemType <<
")";
2002 return op.emitOpError(
"result number of elements (")
2003 << numElements <<
") does not match value number of elements ("
2008 if (
auto arrayAttr = value.
dyn_cast<ArrayAttr>()) {
2011 return op.emitOpError(
2012 "must have spirv.array result type for array value");
2013 Type elemType = arrayType.getElementType();
2014 for (
Attribute element : arrayAttr.getValue()) {
2021 return op.emitOpError(
"cannot have attribute: ") << value;
2031 bool spirv::ConstantOp::isBuildableWith(
Type type) {
2046 if (
auto intType = type.
dyn_cast<IntegerType>()) {
2047 unsigned width = intType.getWidth();
2049 return builder.
create<spirv::ConstantOp>(loc, type,
2051 return builder.
create<spirv::ConstantOp>(
2054 if (
auto floatType = type.dyn_cast<
FloatType>()) {
2055 return builder.
create<spirv::ConstantOp>(
2058 if (
auto vectorType = type.dyn_cast<VectorType>()) {
2059 Type elemType = vectorType.getElementType();
2060 if (elemType.
isa<IntegerType>()) {
2061 return builder.
create<spirv::ConstantOp>(
2064 IntegerAttr::get(elemType, 0).getValue()));
2067 return builder.
create<spirv::ConstantOp>(
2070 FloatAttr::get(elemType, 0.0).getValue()));
2074 llvm_unreachable(
"unimplemented types for ConstantOp::getZero()");
2077 spirv::ConstantOp spirv::ConstantOp::getOne(
Type type,
Location loc,
2079 if (
auto intType = type.
dyn_cast<IntegerType>()) {
2080 unsigned width = intType.getWidth();
2082 return builder.
create<spirv::ConstantOp>(loc, type,
2084 return builder.
create<spirv::ConstantOp>(
2087 if (
auto floatType = type.dyn_cast<
FloatType>()) {
2088 return builder.
create<spirv::ConstantOp>(
2091 if (
auto vectorType = type.dyn_cast<VectorType>()) {
2092 Type elemType = vectorType.getElementType();
2093 if (elemType.
isa<IntegerType>()) {
2094 return builder.
create<spirv::ConstantOp>(
2097 IntegerAttr::get(elemType, 1).getValue()));
2100 return builder.
create<spirv::ConstantOp>(
2103 FloatAttr::get(elemType, 1.0).getValue()));
2107 llvm_unreachable(
"unimplemented types for ConstantOp::getOne()");
2110 void mlir::spirv::ConstantOp::getAsmResultNames(
2112 Type type = getType();
2115 llvm::raw_svector_ostream specialName(specialNameBuffer);
2116 specialName <<
"cst";
2118 IntegerType intTy = type.
dyn_cast<IntegerType>();
2120 if (IntegerAttr intCst = getValue().dyn_cast<IntegerAttr>()) {
2121 if (intTy && intTy.getWidth() == 1) {
2122 return setNameFn(getResult(), (intCst.getInt() ?
"true" :
"false"));
2125 if (intTy.isSignless()) {
2126 specialName << intCst.getInt();
2127 }
else if (intTy.isUnsigned()) {
2128 specialName << intCst.getUInt();
2130 specialName << intCst.getSInt();
2135 specialName <<
'_' << type;
2138 if (
auto vecType = type.
dyn_cast<VectorType>()) {
2139 specialName <<
"_vec_";
2140 specialName << vecType.getDimSize(0);
2142 Type elementType = vecType.getElementType();
2145 specialName <<
"x" << elementType;
2149 setNameFn(getResult(), specialName.str());
2152 void mlir::spirv::AddressOfOp::getAsmResultNames(
2155 llvm::raw_svector_ostream specialName(specialNameBuffer);
2156 specialName << getVariable() <<
"_addr";
2157 setNameFn(getResult(), specialName.str());
2209 spirv::ExecutionModel executionModel,
2210 spirv::FuncOp
function,
2212 build(builder, state,
2213 spirv::ExecutionModelAttr::get(builder.
getContext(), executionModel),
2214 SymbolRefAttr::get(
function), builder.
getArrayAttr(interfaceVars));
2219 spirv::ExecutionModel execModel;
2225 if (parseEnumStrAttr<spirv::ExecutionModelAttr>(execModel, parser, result) ||
2234 FlatSymbolRefAttr var;
2235 NamedAttrList attrs;
2236 if (parser.parseAttribute(var, Type(),
"var_symbol", attrs))
2238 interfaceVars.push_back(var);
2251 auto interfaceVars = getInterface().getValue();
2252 if (!interfaceVars.empty()) {
2254 llvm::interleaveComma(interfaceVars, printer);
2269 spirv::FuncOp
function,
2270 spirv::ExecutionMode executionMode,
2272 build(builder, state, SymbolRefAttr::get(
function),
2273 spirv::ExecutionModeAttr::get(builder.
getContext(), executionMode),
2279 spirv::ExecutionMode execMode;
2282 parseEnumStrAttr<spirv::ExecutionModeAttr>(execMode, parser, result)) {
2294 values.push_back(value.
cast<IntegerAttr>().getInt());
2304 printer <<
" \"" << stringifyExecutionMode(getExecutionMode()) <<
"\"";
2305 auto values = this->getValues();
2309 llvm::interleaveComma(values, printer, [&](
Attribute a) {
2310 printer << a.
cast<IntegerAttr>().getInt();
2349 StringAttr nameAttr;
2355 bool isVariadic =
false;
2357 parser,
false, entryArgs, isVariadic, resultTypes,
2362 for (
auto &arg : entryArgs)
2363 argTypes.push_back(arg.type);
2366 TypeAttr::get(fnType));
2369 spirv::FunctionControl fnControl;
2370 if (parseEnumStrAttr<spirv::FunctionControlAttr>(fnControl, parser, result))
2378 assert(resultAttrs.size() == resultTypes.size());
2380 builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.
name),
2381 getResAttrsAttrName(result.
name));
2396 printer, *
this, fnType.getInputs(),
2397 false, fnType.getResults());
2398 printer <<
" \"" << spirv::stringifyFunctionControl(getFunctionControl())
2402 {spirv::attributeName<spirv::FunctionControl>(),
2403 getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
2404 getFunctionControlAttrName()});
2407 Region &body = this->getBody();
2408 if (!body.empty()) {
2417 return emitOpError(
"cannot have more than one result");
2425 if (
auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
2426 if (fnType.getNumResults() != 0)
2427 return retOp.emitOpError(
"cannot be used in functions returning value");
2428 }
else if (
auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
2429 if (fnType.getNumResults() != 1)
2430 return retOp.emitOpError(
2431 "returns 1 value but enclosing function requires ")
2432 << fnType.getNumResults() <<
" results";
2434 auto retOperandType = retOp.getValue().getType();
2435 auto fnResultType = fnType.getResult(0);
2436 if (retOperandType != fnResultType)
2437 return retOp.emitOpError(
" return value's type (")
2438 << retOperandType <<
") mismatch with function's result type ("
2439 << fnResultType <<
")";
2446 return failure(walkResult.wasInterrupted());
2450 StringRef name, FunctionType type,
2451 spirv::FunctionControl control,
2455 state.
addAttribute(getFunctionTypeAttrName(state.
name), TypeAttr::get(type));
2456 state.
addAttribute(spirv::attributeName<spirv::FunctionControl>(),
2457 builder.
getAttr<spirv::FunctionControlAttr>(control));
2463 Region *spirv::FuncOp::getCallableRegion() {
2464 return isExternal() ? nullptr : &getBody();
2473 ::mlir::ArrayAttr spirv::FuncOp::getCallableArgAttrs() {
2478 ::mlir::ArrayAttr spirv::FuncOp::getCallableResAttrs() {
2479 return getResAttrs().value_or(
nullptr);
2487 auto fnName = getCalleeAttr();
2489 auto funcOp = dyn_cast_or_null<spirv::FuncOp>(
2492 return emitOpError(
"callee function '")
2493 << fnName.getValue() <<
"' not found in nearest symbol table";
2496 auto functionType = funcOp.getFunctionType();
2498 if (getNumResults() > 1) {
2500 "expected callee function to have 0 or 1 result, but provided ")
2504 if (functionType.getNumInputs() != getNumOperands()) {
2505 return emitOpError(
"has incorrect number of operands for callee: expected ")
2506 << functionType.getNumInputs() <<
", but provided "
2507 << getNumOperands();
2510 for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
2511 if (getOperand(i).getType() != functionType.getInput(i)) {
2512 return emitOpError(
"operand type mismatch: expected operand type ")
2513 << functionType.getInput(i) <<
", but provided "
2514 << getOperand(i).getType() <<
" for operand number " << i;
2518 if (functionType.getNumResults() != getNumResults()) {
2520 "has incorrect number of results has for callee: expected ")
2521 << functionType.getNumResults() <<
", but provided "
2525 if (getNumResults() &&
2526 (getResult(0).getType() != functionType.getResult(0))) {
2527 return emitOpError(
"result type mismatch: expected ")
2528 << functionType.getResult(0) <<
", but provided "
2529 << getResult(0).getType();
2536 return (*this)->getAttrOfType<SymbolRefAttr>(
kCallee);
2540 return getArguments();
2587 Type type, StringRef name,
2588 unsigned descriptorSet,
unsigned binding) {
2589 build(builder, state, TypeAttr::get(type), builder.
getStringAttr(name));
2591 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
2594 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
2599 Type type, StringRef name,
2600 spirv::BuiltIn builtin) {
2601 build(builder, state, TypeAttr::get(type), builder.
getStringAttr(name));
2603 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
2610 StringAttr nameAttr;
2636 return parser.
emitError(loc,
"expected spirv.ptr type");
2645 spirv::attributeName<spirv::StorageClass>()};
2653 if (
auto initializer = this->getInitializer()) {
2662 printer <<
" : " << getType();
2666 if (!getType().isa<spirv::PointerType>())
2667 return emitOpError(
"result must be of a !spv.ptr type");
2673 auto storageClass = this->storageClass();
2674 if (storageClass == spirv::StorageClass::Generic ||
2675 storageClass == spirv::StorageClass::Function) {
2676 return emitOpError(
"storage class cannot be '")
2677 << stringifyStorageClass(storageClass) <<
"'";
2683 (*this)->getParentOp(), init.getAttr());
2688 !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp>(initOp)) {
2689 return emitOpError(
"initializer must be result of a "
2690 "spirv.SpecConstant or spirv.GlobalVariable op");
2702 spirv::Scope scope = getExecutionScope();
2703 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2704 return emitOpError(
"execution scope must be 'Workgroup' or 'Subgroup'");
2706 if (
auto localIdTy = getLocalid().getType().dyn_cast<VectorType>())
2707 if (localIdTy.getNumElements() != 2 && localIdTy.getNumElements() != 3)
2708 return emitOpError(
"localid is a vector and can be with only "
2709 " 2 or 3 components, actual number is ")
2710 << localIdTy.getNumElements();
2720 spirv::Scope scope = getExecutionScope();
2721 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2722 return emitOpError(
"execution scope must be 'Workgroup' or 'Subgroup'");
2732 spirv::Scope scope = getExecutionScope();
2733 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2734 return emitOpError(
"execution scope must be 'Workgroup' or 'Subgroup'");
2739 if (
auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>())
2742 if (targetEnv.getVersion() < spirv::Version::V_1_5) {
2743 auto *idOp = getId().getDefiningOp();
2744 if (!idOp || !isa<spirv::ConstantOp,
2745 spirv::ReferenceOfOp>(idOp))
2746 return emitOpError(
"id must be the result of a constant op");
2756 template <
typename OpTy>
2758 spirv::Scope scope = op.getExecutionScope();
2759 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2760 return op.emitOpError(
"execution scope must be 'Workgroup' or 'Subgroup'");
2762 if (op.getOperands().back().getType().isSignedInteger())
2763 return op.emitOpError(
"second operand must be a singless/unsigned integer");
2788 spirv::StorageClass storageClass;
2797 if (
auto valVecTy = elementType.
dyn_cast<VectorType>())
2809 printer <<
" " << getPtr() <<
" : " << getType();
2826 spirv::StorageClass storageClass;
2837 if (
auto valVecTy = elementType.
dyn_cast<VectorType>())
2848 printer <<
" " << getPtr() <<
", " << getValue() <<
" : "
2849 << getValue().getType();
2864 spirv::Scope scope = getExecutionScope();
2865 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2866 return emitOpError(
"execution scope must be 'Workgroup' or 'Subgroup'");
3104 Value basePtr, MemoryAccessAttr memoryAccess,
3105 IntegerAttr alignment) {
3107 build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess,
3113 spirv::StorageClass storageClass;
3134 StringRef sc = stringifyStorageClass(
3135 getPtr().getType().cast<spirv::PointerType>().getStorageClass());
3136 printer <<
" \"" << sc <<
"\" " << getPtr();
3141 printer <<
" : " << getType();
3165 if (parseControlAttribute<spirv::LoopControlAttr, spirv::LoopControl>(parser,
3172 auto control = getLoopControl();
3174 printer <<
" control(" << spirv::stringifyLoopControl(control) <<
")";
3184 if (!llvm::hasSingleElement(srcBlock))
3187 auto branchOp = dyn_cast<spirv::BranchOp>(srcBlock.
back());
3188 return branchOp && branchOp.getSuccessor() == &dstBlock;
3192 auto *op = getOperation();
3228 return emitOpError(
"last block must be the merge block with only one "
3229 "'spirv.mlir.merge' op");
3231 if (std::next(region.begin()) == region.end())
3233 "must have an entry block branching to the loop header block");
3237 if (std::next(region.begin(), 2) == region.end())
3239 "must have a loop header block branched from the entry block");
3241 Block &header = *std::next(region.begin(), 1);
3245 "entry block must only have one 'spirv.Branch' op to the second block");
3247 if (std::next(region.begin(), 3) == region.end())
3249 "requires a loop continue block branching to the loop header block");
3251 Block &cont = *std::prev(region.end(), 2);
3257 [&](
unsigned index) { return cont.getSuccessor(index) == &header; }))
3258 return emitOpError(
"second to last block must be the loop continue "
3259 "block that branches to the loop header block");
3263 for (
auto &block : llvm::make_range(std::next(region.begin(), 2),
3264 std::prev(region.end(), 2))) {
3265 for (
auto i : llvm::seq<unsigned>(0, block.getNumSuccessors())) {
3266 if (block.getSuccessor(i) == &header) {
3267 return emitOpError(
"can only have the entry and loop continue "
3268 "block branching to the loop header block");
3276 Block *spirv::LoopOp::getEntryBlock() {
3277 assert(!getBody().empty() &&
"op region should not be empty!");
3278 return &getBody().front();
3281 Block *spirv::LoopOp::getHeaderBlock() {
3282 assert(!getBody().empty() &&
"op region should not be empty!");
3284 return &*std::next(getBody().begin());
3287 Block *spirv::LoopOp::getContinueBlock() {
3288 assert(!getBody().empty() &&
"op region should not be empty!");
3290 return &*std::prev(getBody().end(), 2);
3293 Block *spirv::LoopOp::getMergeBlock() {
3294 assert(!getBody().empty() &&
"op region should not be empty!");
3296 return &getBody().back();
3299 void spirv::LoopOp::addEntryAndMergeBlock() {
3300 assert(getBody().empty() &&
"entry and merge block already exist");
3301 getBody().push_back(
new Block());
3302 auto *mergeBlock =
new Block();
3303 getBody().push_back(mergeBlock);
3307 builder.
create<spirv::MergeOp>(getLoc());
3323 auto *parentOp = (*this)->getParentOp();
3324 if (!parentOp || !isa<spirv::SelectionOp, spirv::LoopOp>(parentOp))
3326 "expected parent op to be 'spirv.mlir.selection' or 'spirv.mlir.loop'");
3329 Block &parentLastBlock = (*this)->getParentRegion()->
back();
3331 return emitOpError(
"can only be used in the last block of "
3332 "'spirv.mlir.selection' or 'spirv.mlir.loop'");
3341 std::optional<StringRef> name) {
3351 spirv::AddressingModel addressingModel,
3352 spirv::MemoryModel memoryModel,
3353 std::optional<VerCapExtAttr> vceTriple,
3354 std::optional<StringRef> name) {
3357 builder.
getAttr<spirv::AddressingModelAttr>(addressingModel));
3359 builder.
getAttr<spirv::MemoryModelAttr>(memoryModel));
3363 state.
addAttribute(getVCETripleAttrName(), *vceTriple);
3374 StringAttr nameAttr;
3379 spirv::AddressingModel addrModel;
3380 spirv::MemoryModel memoryModel;
3381 if (::parseEnumKeywordAttr<spirv::AddressingModelAttr>(addrModel, parser,
3383 ::parseEnumKeywordAttr<spirv::MemoryModelAttr>(memoryModel, parser,
3390 spirv::ModuleOp::getVCETripleAttrName(),
3407 if (std::optional<StringRef> name = getName()) {
3416 auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
3417 auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
3418 elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
3421 if (std::optional<spirv::VerCapExtAttr> triple = getVceTriple()) {
3422 printer <<
" requires " << *triple;
3423 elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
3432 Dialect *dialect = (*this)->getDialect();
3437 for (
auto &op : *getBody()) {
3439 return op.
emitError(
"'spirv.module' can only contain spirv.* ops");
3444 if (
auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
3445 auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.getFn());
3447 return entryPointOp.emitError(
"function '")
3448 << entryPointOp.getFn() <<
"' not found in 'spirv.module'";
3450 if (
auto interface = entryPointOp.getInterface()) {
3454 return entryPointOp.emitError(
3455 "expected symbol reference for interface "
3456 "specification instead of '")
3460 table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
3462 return entryPointOp.emitError(
"expected spirv.GlobalVariable "
3463 "symbol reference instead of'")
3464 << varSymRef <<
"'";
3469 auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
3470 funcOp, entryPointOp.getExecutionModel());
3471 auto entryPtIt = entryPoints.find(key);
3472 if (entryPtIt != entryPoints.end()) {
3473 return entryPointOp.emitError(
"duplicate of a previous EntryPointOp");
3475 entryPoints[key] = entryPointOp;
3476 }
else if (
auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
3477 if (funcOp.isExternal())
3478 return op.
emitError(
"'spirv.module' cannot contain external functions");
3481 for (
auto &block : funcOp)
3482 for (
auto &op : block) {
3485 "functions in 'spirv.module' can only contain spirv.* ops");
3499 (*this)->getParentOp(), getSpecConstAttr());
3502 auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
3504 constType = specConstOp.getDefaultValue().getType();
3506 auto specConstCompositeOp =
3507 dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
3508 if (specConstCompositeOp)
3509 constType = specConstCompositeOp.getType();
3511 if (!specConstOp && !specConstCompositeOp)
3513 "expected spirv.SpecConstant or spirv.SpecConstantComposite symbol");
3515 if (getReference().getType() != constType)
3516 return emitOpError(
"result type mismatch with the referenced "
3517 "specialization constant's type");
3545 if (
auto conditionTy = getCondition().getType().dyn_cast<VectorType>()) {
3546 auto resultVectorTy = getResult().getType().dyn_cast<VectorType>();
3547 if (!resultVectorTy) {
3548 return emitOpError(
"result expected to be of vector type when "
3549 "condition is of vector type");
3551 if (resultVectorTy.getNumElements() != conditionTy.getNumElements()) {
3552 return emitOpError(
"result should have the same number of elements as "
3553 "the condition when condition is of vector type");
3566 spirv::SelectionControl>(parser, result))
3572 auto control = getSelectionControl();
3574 printer <<
" control(" << spirv::stringifySelectionControl(control) <<
")";
3581 auto *op = getOperation();
3612 return emitOpError(
"last block must be the merge block with only one "
3613 "'spirv.mlir.merge' op");
3615 if (std::next(region.begin()) == region.end())
3616 return emitOpError(
"must have a selection header block");
3621 Block *spirv::SelectionOp::getHeaderBlock() {
3622 assert(!getBody().empty() &&
"op region should not be empty!");
3624 return &getBody().front();
3627 Block *spirv::SelectionOp::getMergeBlock() {
3628 assert(!getBody().empty() &&
"op region should not be empty!");
3630 return &getBody().back();
3633 void spirv::SelectionOp::addMergeBlock() {
3634 assert(getBody().empty() &&
"entry and merge block already exist");
3635 auto *mergeBlock =
new Block();
3636 getBody().push_back(mergeBlock);
3640 builder.
create<spirv::MergeOp>(getLoc());
3643 spirv::SelectionOp spirv::SelectionOp::createIfThen(
3649 selectionOp.addMergeBlock();
3650 Block *mergeBlock = selectionOp.getMergeBlock();
3651 Block *thenBlock =
nullptr;
3658 builder.
create<spirv::BranchOp>(loc, mergeBlock);
3665 builder.
create<spirv::BranchConditionalOp>(
3666 loc, condition, thenBlock,
3680 StringAttr nameAttr;
3689 IntegerAttr specIdAttr;
3707 if (
auto specID = (*this)->getAttrOfType<IntegerAttr>(
kSpecIdAttrName))
3709 printer <<
" = " << getDefaultValue();
3713 if (
auto specID = (*this)->getAttrOfType<IntegerAttr>(
kSpecIdAttrName))
3714 if (specID.getValue().isNegative())
3715 return emitOpError(
"SpecId cannot be negative");
3717 auto value = getDefaultValue();
3718 if (value.
isa<IntegerAttr, FloatAttr>()) {
3721 return emitOpError(
"default value bitwidth disallowed");
3725 "default value can only be a bool, integer, or float scalar");
3734 spirv::StorageClass storageClass;
3755 StringRef sc = stringifyStorageClass(
3756 getPtr().getType().cast<spirv::PointerType>().getStorageClass());
3757 printer <<
" \"" << sc <<
"\" " << getPtr() <<
", " << getValue();
3761 printer <<
" : " << getValue().getType();
3778 auto *block = (*this)->getBlock();
3781 if (block->isEntryBlock())
3782 return emitOpError(
"cannot be used in reachable block");
3783 if (block->hasNoPredecessors())
3799 std::optional<OpAsmParser::UnresolvedOperand> initInfo;
3821 return parser.
emitError(loc,
"expected spirv.ptr type");
3832 ptrType.getStorageClass());
3833 result.
addAttribute(spirv::attributeName<spirv::StorageClass>(), attr);
3840 spirv::attributeName<spirv::StorageClass>()};
3842 if (getNumOperands() != 0)
3843 printer <<
" init(" << getInitializer() <<
")";
3846 printer <<
" : " << getType();
3853 if (getStorageClass() != spirv::StorageClass::Function) {
3855 "can only be used to model function-level variables. Use "
3856 "spirv.GlobalVariable for module-level variables.");
3860 if (getStorageClass() != pointerType.getStorageClass())
3862 "storage class must match result pointer's storage class");
3864 if (getNumOperands() != 0) {
3867 auto *initOp = getOperand(0).getDefiningOp();
3868 if (!initOp || !isa<spirv::ConstantOp,
3869 spirv::ReferenceOfOp,
3870 spirv::AddressOfOp>(initOp))
3871 return emitOpError(
"initializer must be the result of a "
3872 "constant or spirv.GlobalVariable op");
3876 auto *op = getOperation();
3878 stringifyDecoration(spirv::Decoration::DescriptorSet));
3879 auto bindingName = llvm::convertToSnakeFromCamelCase(
3880 stringifyDecoration(spirv::Decoration::Binding));
3881 auto builtInName = llvm::convertToSnakeFromCamelCase(
3882 stringifyDecoration(spirv::Decoration::BuiltIn));
3886 return emitOpError(
"cannot have '")
3887 << attr <<
"' attribute (only allowed in spirv.GlobalVariable)";
3898 VectorType resultType = getType().cast<VectorType>();
3900 size_t numResultElements = resultType.getNumElements();
3901 if (numResultElements != getComponents().size())
3902 return emitOpError(
"result type element count (")
3903 << numResultElements
3904 <<
") mismatch with the number of component selectors ("
3905 << getComponents().size() <<
")";
3907 size_t totalSrcElements =
3911 for (
const auto &selector : getComponents().getAsValueRange<IntegerAttr>()) {
3912 uint32_t index = selector.getZExtValue();
3913 if (index >= totalSrcElements &&
3914 index != std::numeric_limits<uint32_t>().
max())
3915 return emitOpError(
"component selector ")
3916 << index <<
" out of range: expected to be in [0, "
3917 << totalSrcElements <<
") or 0xffffffff";
3939 {ptrType, strideType, columnMajorType},
3949 printer <<
" " << getPointer() <<
", " << getStride() <<
", "
3950 << getColumnmajor();
3952 if (
auto memAccess = getMemoryAccess())
3953 printer <<
" [\"" << stringifyMemoryAccess(*memAccess) <<
"\"]";
3954 printer <<
" : " << getPointer().getType() <<
" as " << getType();
3962 "Pointer must point to a scalar or vector type but provided ")
3964 spirv::StorageClass storage =
3966 if (storage != spirv::StorageClass::Workgroup &&
3967 storage != spirv::StorageClass::StorageBuffer &&
3968 storage != spirv::StorageClass::PhysicalStorageBuffer)
3970 "Pointer storage class must be Workgroup, StorageBuffer or "
3971 "PhysicalStorageBufferEXT but provided ")
3972 << stringifyStorageClass(storage);
3978 getResult().getType());
3999 operandInfo, {ptrType, elementType, strideType, columnMajorType},
4008 printer <<
" " << getPointer() <<
", " << getObject() <<
", " << getStride()
4009 <<
", " << getColumnmajor();
4011 if (
auto memAccess = getMemoryAccess())
4012 printer <<
" [\"" << stringifyMemoryAccess(*memAccess) <<
"\"]";
4013 printer <<
" : " << getPointer().getType() <<
", " << getOperand(1).getType();
4018 getObject().getType());
4027 if (op.getC().getType() != op.getResult().getType())
4028 return op.emitOpError(
"result and third operand must have the same type");
4033 if (typeA.getRows() != typeR.
getRows() ||
4034 typeA.getColumns() != typeB.
getRows() ||
4036 return op.emitOpError(
"matrix size must match");
4037 if (typeR.
getScope() != typeA.getScope() ||
4040 return op.emitOpError(
"matrix scope must match");
4043 if (isa<IntegerType>(elementTypeA) && isa<IntegerType>(elementTypeB)) {
4044 if (elementTypeA.cast<IntegerType>().getWidth() !=
4045 elementTypeB.cast<IntegerType>().getWidth())
4046 return op.emitOpError(
4047 "matrix A and B integer element types must be the same bit width");
4048 }
else if (elementTypeA != elementTypeB) {
4049 return op.emitOpError(
4050 "matrix A and B non-integer element types must match");
4053 return op.emitOpError(
"matrix accumulator element type must match");
4066 "Pointer must point to a scalar or vector type but provided ")
4068 spirv::StorageClass storage =
4070 if (storage != spirv::StorageClass::Workgroup &&
4071 storage != spirv::StorageClass::CrossWorkgroup &&
4072 storage != spirv::StorageClass::UniformConstant &&
4073 storage != spirv::StorageClass::Generic)
4074 return op->
emitError(
"Pointer storage class must be Workgroup or "
4075 "CrossWorkgroup but provided ")
4076 << stringifyStorageClass(storage);
4086 getResult().getType());
4095 getObject().getType());
4103 if (op.getC().getType() != op.getResult().getType())
4104 return op.emitOpError(
"result and third operand must have the same type");
4109 if (typeA.getRows() != typeR.
getRows() ||
4110 typeA.getColumns() != typeB.
getRows() ||
4112 return op.emitOpError(
"matrix size must match");
4113 if (typeR.
getScope() != typeA.getScope() ||
4116 return op.emitOpError(
"matrix scope must match");
4119 return op.emitOpError(
"matrix element type must match");
4132 if (
auto inputCoopmat =
4133 getMatrix().getType().dyn_cast<spirv::CooperativeMatrixNVType>()) {
4134 if (inputCoopmat.getElementType() != getScalar().getType())
4135 return emitError(
"input matrix components' type and scaling value must "
4136 "have the same type");
4142 if (getScalar().getType() != inputMatrix.getElementType())
4143 return emitError(
"input matrix components' type and scaling value must "
4144 "have the same type");
4156 StringRef targetStorageClass = stringifyStorageClass(
4157 getTarget().getType().cast<spirv::PointerType>().getStorageClass());
4158 printer <<
" \"" << targetStorageClass <<
"\" " << getTarget() <<
", ";
4160 StringRef sourceStorageClass = stringifyStorageClass(
4161 getSource().getType().cast<spirv::PointerType>().getStorageClass());
4162 printer <<
" \"" << sourceStorageClass <<
"\" " << getSource();
4167 getSourceMemoryAccess(),
4168 getSourceAlignment());
4174 printer <<
" : " << pointeeType;
4179 spirv::StorageClass targetStorageClass;
4182 spirv::StorageClass sourceStorageClass;
4226 if (targetType != sourceType)
4227 return emitOpError(
"both operands must be pointers to the same type");
4252 if (inputMatrix.getNumRows() != resultMatrix.
getNumColumns())
4253 return emitError(
"input matrix rows count must be equal to "
4254 "output matrix columns count");
4256 if (inputMatrix.getNumColumns() != resultMatrix.
getNumRows())
4257 return emitError(
"input matrix columns count must be equal to "
4258 "output matrix rows count");
4261 if (inputMatrix.getElementType() != resultMatrix.
getElementType())
4262 return emitError(
"input and output matrices must have the same "
4278 if (leftMatrix.getNumColumns() != rightMatrix.
getNumRows())
4279 return emitError(
"left matrix columns' count must be equal to "
4280 "the right matrix rows' count");
4285 "right and result matrices must have equal columns' count");
4289 return emitError(
"right and result matrices' component type must"
4293 if (leftMatrix.getElementType() != resultMatrix.
getElementType())
4294 return emitError(
"left and result matrices' component type"
4295 " must be the same");
4298 if (leftMatrix.getNumRows() != resultMatrix.
getNumRows())
4299 return emitError(
"left and result matrices must have equal rows' count");
4311 StringAttr compositeName;
4323 const char *attrName =
"spec_const";
4330 constituents.push_back(specConstRef);
4352 auto constituents = this->getConstituents().getValue();
4354 if (!constituents.empty())
4355 llvm::interleaveComma(constituents, printer);
4357 printer <<
") : " << getType();
4362 auto constituents = this->getConstituents().getValue();
4365 return emitError(
"result type must be a composite type, but provided ")
4369 return emitError(
"unsupported composite type ") << cType;
4371 return emitError(
"unsupported composite type ") << cType;
4373 return emitError(
"has incorrect number of operands: expected ")
4374 << cType.getNumElements() <<
", but provided "
4375 << constituents.size();
4377 for (
auto index : llvm::seq<uint32_t>(0, constituents.size())) {
4380 auto constituentSpecConstOp =
4382 (*this)->getParentOp(), constituent.getAttr()));
4384 if (constituentSpecConstOp.getDefaultValue().getType() !=
4385 cType.getElementType(index))
4386 return emitError(
"has incorrect types of operands: expected ")
4387 << cType.getElementType(index) <<
", but provided "
4388 << constituentSpecConstOp.getDefaultValue().getType();
4426 printer <<
" wraps ";
4430 LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
4431 Block &block = getRegion().getBlocks().
front();
4434 return emitOpError(
"expected exactly 2 nested ops");
4439 return emitOpError(
"invalid enclosed op");
4442 if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
4443 spirv::SpecConstantOperationOp>(operand.getDefiningOp()))
4445 "invalid operand, must be defined by a constant operation");
4459 return emitError(
"result type must be a struct type with two memebers");
4463 VectorType exponentVecTy = exponentTy.
dyn_cast<VectorType>();
4464 IntegerType exponentIntTy = exponentTy.
dyn_cast<IntegerType>();
4466 Type operandTy = getOperand().getType();
4467 VectorType operandVecTy = operandTy.
dyn_cast<VectorType>();
4470 if (significandTy != operandTy)
4471 return emitError(
"member zero of the resulting struct type must be the "
4472 "same type as the operand");
4474 if (exponentVecTy) {
4475 IntegerType componentIntTy =
4476 exponentVecTy.getElementType().dyn_cast<IntegerType>();
4477 if (!componentIntTy || componentIntTy.getWidth() != 32)
4478 return emitError(
"member one of the resulting struct type must"
4479 "be a scalar or vector of 32 bit integer type");
4480 }
else if (!exponentIntTy || exponentIntTy.getWidth() != 32) {
4481 return emitError(
"member one of the resulting struct type "
4482 "must be a scalar or vector of 32 bit integer type");
4486 if (operandVecTy && exponentVecTy &&
4487 (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
4490 if (operandFTy && exponentIntTy)
4493 return emitError(
"member one of the resulting struct type must have the same "
4494 "number of components as the operand type");
4502 Type significandType = getX().getType();
4503 Type exponentType = getExp().getType();
4505 if (significandType.
isa<
FloatType>() != exponentType.
isa<IntegerType>())
4506 return emitOpError(
"operands must both be scalars or vectors");
4509 if (
auto vectorType = type.
dyn_cast<VectorType>())
4510 return vectorType.getNumElements();
4515 return emitOpError(
"operands must have the same number of elements");
4525 VectorType resultType = getResult().getType().cast<VectorType>();
4526 auto sampledImageType =
4530 if (resultType.getNumElements() != 4)
4531 return emitOpError(
"result type must be a vector of four components");
4533 Type elementType = resultType.getElementType();
4534 Type sampledElementType = imageType.getElementType();
4535 if (!sampledElementType.
isa<NoneType>() && elementType != sampledElementType)
4537 "the component type of result must be the same as sampled type of the "
4538 "underlying image type");
4540 spirv::Dim imageDim = imageType.getDim();
4541 spirv::ImageSamplingInfo imageMS = imageType.getSamplingInfo();
4543 if (imageDim != spirv::Dim::Dim2D && imageDim != spirv::Dim::Cube &&
4544 imageDim != spirv::Dim::Rect)
4546 "the Dim operand of the underlying image type must be 2D, Cube, or "
4549 if (imageMS != spirv::ImageSamplingInfo::SingleSampled)
4550 return emitOpError(
"the MS operand of the underlying image type must be 0");
4552 spirv::ImageOperandsAttr attr = getImageoperandsAttr();
4553 auto operandArguments = getOperandArguments();
4588 Type resultType = getResult().getType();
4590 spirv::Dim dim = imageType.
getDim();
4594 case spirv::Dim::Dim1D:
4595 case spirv::Dim::Dim2D:
4596 case spirv::Dim::Dim3D:
4597 case spirv::Dim::Cube:
4598 if (samplingInfo != spirv::ImageSamplingInfo::MultiSampled &&
4599 samplerInfo != spirv::ImageSamplerUseInfo::SamplerUnknown &&
4600 samplerInfo != spirv::ImageSamplerUseInfo::NoSampler)
4602 "if Dim is 1D, 2D, 3D, or Cube, "
4603 "it must also have either an MS of 1 or a Sampled of 0 or 2");
4605 case spirv::Dim::Buffer:
4606 case spirv::Dim::Rect:
4609 return emitError(
"the Dim operand of the image type must "
4610 "be 1D, 2D, 3D, Buffer, Cube, or Rect");
4613 unsigned componentNumber = 0;
4615 case spirv::Dim::Dim1D:
4616 case spirv::Dim::Buffer:
4617 componentNumber = 1;
4619 case spirv::Dim::Dim2D:
4620 case spirv::Dim::Cube:
4621 case spirv::Dim::Rect:
4622 componentNumber = 2;
4624 case spirv::Dim::Dim3D:
4625 componentNumber = 3;
4631 if (imageType.
getArrayedInfo() == spirv::ImageArrayedInfo::Arrayed)
4632 componentNumber += 1;
4634 unsigned resultComponentNumber = 1;
4635 if (
auto resultVectorType = resultType.dyn_cast<VectorType>())
4636 resultComponentNumber = resultVectorType.getNumElements();
4638 if (componentNumber != resultComponentNumber)
4639 return emitError(
"expected the result to have ")
4640 << componentNumber <<
" component(s), but found "
4641 << resultComponentNumber <<
" component(s)";
4663 if (indicesInfo.empty())
4671 if (indicesTypes.size() != indicesInfo.size())
4674 <<
" indices types' count must be equal to indices info count";
4688 template <
typename Op>
4691 ret[0] = op.getElement();
4692 llvm::copy(op.getIndices(), ret.begin() + 1);
4700 void spirv::InBoundsPtrAccessChainOp::build(
OpBuilder &builder,
4705 assert(type &&
"Unable to deduce return type based on basePtr and indices");
4706 build(builder, state, type, basePtr, element, indices);
4712 spirv::InBoundsPtrAccessChainOp::getOperationName(), parser, result);
4731 assert(type &&
"Unable to deduce return type based on basePtr and indices");
4732 build(builder, state, type, basePtr, element, indices);
4754 if (getVector().getType() != getType())
4755 return emitOpError(
"vector operand and result type mismatch");
4757 if (getScalar().getType() != scalarType)
4758 return emitOpError(
"scalar operand and result element type match");
4766 template <
typename Op>
4768 spirv::Scope scope = op.getExecutionScope();
4769 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
4770 return op.
emitOpError(
"execution scope must be 'Workgroup' or 'Subgroup'");
4801 "Not an integer dot product op?");
4802 assert(op->
getNumResults() == 1 &&
"Expected a single result");
4806 return op->
emitOpError(
"requires the same type for both vector operands");
4808 unsigned expectedNumAttrs = 0;
4809 if (
auto intTy = factorTy.
dyn_cast<IntegerType>()) {
4811 auto packedVectorFormat =
4814 if (!packedVectorFormat)
4815 return op->
emitOpError(
"requires Packed Vector Format attribute for "
4816 "integer vector operands");
4818 assert(packedVectorFormat.getValue() ==
4819 spirv::PackedVectorFormat::PackedVectorFormat4x8Bit &&
4820 "Unknown Packed Vector Format");
4821 if (intTy.getWidth() != 32)
4823 llvm::formatv(
"with specified Packed Vector Format ({0}) requires "
4824 "integer vector operands to be 32-bits wide",
4825 packedVectorFormat.getValue()));
4829 "with invalid format attribute for vector operands of type '{0}'",
4833 if (op->
getAttrs().size() > expectedNumAttrs)
4835 "op only supports the 'format' #spirv.packed_vector_format attribute");
4841 "requires the same accumulator operand and result types");
4845 if (factorBitWidth > resultBitWidth)
4847 llvm::formatv(
"result type has insufficient bit-width ({0} bits) "
4848 "for the specified vector operand type ({1} bits)",
4849 resultBitWidth, factorBitWidth));
4855 return spirv::Version::V_1_0;
4859 return spirv::Version::V_1_6;
4866 static const auto extension = spirv::Extension::SPV_KHR_integer_dot_product;
4874 static const auto dotProductCap = spirv::Capability::DotProduct;
4875 static const auto dotProductInput4x8BitPackedCap =
4876 spirv::Capability::DotProductInput4x8BitPacked;
4877 static const auto dotProductInput4x8BitCap =
4878 spirv::Capability::DotProductInput4x8Bit;
4879 static const auto dotProductInputAllCap =
4880 spirv::Capability::DotProductInputAll;
4885 if (
auto intTy = factorTy.
dyn_cast<IntegerType>()) {
4887 .
cast<spirv::PackedVectorFormatAttr>();
4888 if (formatAttr.getValue() ==
4889 spirv::PackedVectorFormat::PackedVectorFormat4x8Bit)
4890 capabilities.push_back(dotProductInput4x8BitPackedCap);
4892 return capabilities;
4895 auto vecTy = factorTy.
cast<VectorType>();
4896 if (vecTy.getElementTypeBitWidth() == 8) {
4897 capabilities.push_back(dotProductInput4x8BitCap);
4898 return capabilities;
4901 capabilities.push_back(dotProductInputAllCap);
4902 return capabilities;
4905 #define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpName) \
4906 LogicalResult OpName::verify() { return verifyIntegerDotProduct(*this); } \
4907 SmallVector<ArrayRef<spirv::Extension>, 1> OpName::getExtensions() { \
4908 return getIntegerDotProductExtensions(); \
4910 SmallVector<ArrayRef<spirv::Capability>, 1> OpName::getCapabilities() { \
4911 return getIntegerDotProductCapabilities(*this); \
4913 std::optional<spirv::Version> OpName::getMinVersion() { \
4914 return getIntegerDotProductMinVersion(); \
4916 std::optional<spirv::Version> OpName::getMaxVersion() { \
4917 return getIntegerDotProductMaxVersion(); \
4927 #undef SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP
4931 #include "mlir/Dialect/SPIRV/IR/SPIRVAvailability.cpp.inc"
4934 #define GET_OP_CLASSES
4935 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"
4940 #include "mlir/Dialect/SPIRV/IR/SPIRVOpAvailabilityImpl.inc"
static std::string bindingName()
Returns the string name of the Binding decoration.
static std::string descriptorSetName()
Returns the string name of the DescriptorSet decoration.
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
Operation::operand_range getIndices(Operation *op)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser, OperationState &result)
static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser, OperationState &state)
constexpr char kValuesAttrName[]
static LogicalResult verifyPointerAndJointMatrixType(Operation *op, Type pointer, Type jointMatrix)
static ParseResult parseVariableDecorations(OpAsmParser &parser, OperationState &state)
constexpr char kClusterSize[]
static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp)
constexpr char kValueAttrName[]
static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, Type opType)
constexpr char kEqualSemanticsAttrName[]
static LogicalResult verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics)
static LogicalResult verifyCastOp(Operation *op, bool requireSameBitWidth=true, bool skipBitWidthCheck=false)
static ParseResult parseAtomicUpdateOp(OpAsmParser &parser, OperationState &state, bool hasValue)
constexpr char kExecutionScopeAttrName[]
static void printVariableDecorations(Operation *op, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs)
static unsigned getBitWidth(Type type)
constexpr char kTypeAttrName[]
static bool isMergeBlock(Block &block)
Returns true if the given block only contains one spirv.mlir.merge op.
static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, OperationState &result)
constexpr char kUnequalSemanticsAttrName[]
static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices)
static std::optional< spirv::Version > getIntegerDotProductMinVersion()
constexpr char kIndicesAttrName[]
static void printAtomicCompareExchangeImpl(T atomOp, OpAsmPrinter &printer)
static ParseResult parseImageOperands(OpAsmParser &parser, spirv::ImageOperandsAttr &attr)
static StringRef stringifyTypeName()
static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op)
static LogicalResult verifyGroupNonUniformShuffleOp(OpTy op)
static LogicalResult verifyAtomicUpdateOp(Operation *op)
constexpr char kBranchWeightAttrName[]
constexpr char kCompositeSpecConstituentsName[]
static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser, OperationState &state)
constexpr char kInterfaceAttrName[]
static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr, Value val)
constexpr char kSourceAlignmentAttrName[]
constexpr char kMemoryScopeAttrName[]
static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer, Type coopMatrix)
static bool isDirectInModuleLikeOp(Operation *op)
Returns true if the given op is an module-like op that maintains a symbol table.
static LogicalResult verifyShiftOp(Operation *op)
constexpr char kGroupOperationAttrName[]
static Type getUnaryOpResultType(Type operandType)
Result of a logical op must be a scalar or vector of boolean type.
static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser, OperationState &state)
Parses optional memory access attributes attached to a memory access operand/pointer.
static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc)
static LogicalResult verifyJointMatrixMad(spirv::INTELJointMatrixMadOp op)
#define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpName)
constexpr char kMemoryAccessAttrName[]
constexpr char kControl[]
static void printMemoryAccessAttribute(MemoryOpTy memoryOp, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs, std::optional< spirv::MemoryAccess > memoryAccessAtrrValue=std::nullopt, std::optional< uint32_t > alignmentAttrValue=std::nullopt)
constexpr char kDefaultValueAttrName[]
static LogicalResult verifyGroupOp(Op op)
static LogicalResult extractValueFromConstOp(Operation *op, int32_t &value)
static void printAtomicUpdateOp(Operation *op, OpAsmPrinter &printer)
static LogicalResult verifyIntegerDotProduct(Operation *op)
StringRef stringifyTypeName< FloatType >()
static auto concatElemAndIndices(Op op)
static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op, Value ptr, Value val)
static void printGroupNonUniformArithmeticOp(Operation *groupOp, OpAsmPrinter &printer)
static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer)
static bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock)
Returns true if the given srcBlock contains only one spirv.Branch to the given dstBlock.
static bool isNestedInFunctionOpInterface(Operation *op)
Returns true if the given op is a function-like op or nested in a function-like op without a module-l...
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static void printOneResultOp(Operation *op, OpAsmPrinter &p)
static SmallVector< ArrayRef< spirv::Extension >, 1 > getIntegerDotProductExtensions()
static void printImageOperands(OpAsmPrinter &printer, Operation *imageOp, spirv::ImageOperandsAttr attr)
static ParseResult parsePtrAccessChainOpImpl(StringRef opName, OpAsmParser &parser, OperationState &state)
constexpr char kSourceMemoryAccessAttrName[]
constexpr char kSpecIdAttrName[]
constexpr char kAlignmentAttrName[]
static SmallVector< ArrayRef< spirv::Capability >, 1 > getIntegerDotProductCapabilities(Operation *op)
static std::optional< spirv::Version > getIntegerDotProductMaxVersion()
static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp)
static ArrayAttr getStrArrayAttrForEnumList(Builder &builder, ArrayRef< Ty > enumValues, function_ref< StringRef(Ty)> stringifyFn)
constexpr char kPackedVectorFormatAttrName[]
static ParseResult parseAtomicCompareExchangeImpl(OpAsmParser &parser, OperationState &state)
StringRef stringifyTypeName< IntegerType >()
static void printSourceMemoryAccessAttribute(MemoryOpTy memoryOp, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs, std::optional< spirv::MemoryAccess > memoryAccessAtrrValue=std::nullopt, std::optional< uint32_t > alignmentAttrValue=std::nullopt)
static LogicalResult verifyImageOperands(Op imageOp, spirv::ImageOperandsAttr attr, Operation::operand_range operands)
static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp)
constexpr char kSemanticsAttrName[]
static LogicalResult verifyAtomicCompareExchangeImpl(T atomOp)
static ParseResult parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next string attribute in parser as an enumerant of the given EnumClass.
static LogicalResult verifyCoopMatrixMulAdd(spirv::NVCooperativeMatrixMulAddOp op)
constexpr char kInitializerAttrName[]
static ParseResult parseControlAttribute(OpAsmParser &parser, OperationState &state, StringRef attrName=spirv::attributeName< EnumClass >())
Parses Function, Selection and Loop control attributes.
constexpr char kFnNameAttrName[]
static void printArithmeticExtendedBinaryOp(Operation *op, OpAsmPrinter &printer)
static int64_t getNumElements(ShapedType type)
static bool isZero(OpFoldResult v)
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseOptionalSymbolName(StringAttr &result)=0
Parse an optional -identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
ParseResult addTypesToList(ArrayRef< Type > types, SmallVectorImpl< Type > &result)
Add the specified types to the end of the specified type list and return success.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
U dyn_cast_or_null() const
bool isa() const
Casting utility functions.
Block represents an ordered list of Operations.
unsigned getNumSuccessors()
Operation * getTerminator()
Get the terminator operation of this block.
OpListType & getOperations()
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI32IntegerAttr(int32_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
FloatAttr getFloatAttr(Type type, double value)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
StringAttr getStringAttr(const Twine &bytes)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getStrArrayAttr(ArrayRef< StringRef > values)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual ParseResult parseSuccessorAndUseList(Block *&dest, SmallVectorImpl< Value > &operands)=0
Parse a single operation successor and its operand list.
virtual OptionalParseResult parseOptionalRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual Operation * parseGenericOperation(Block *insertBlock, Block::iterator insertPt)=0
Parse an operation in its generic form.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printSuccessorAndUseList(Block *successor, ValueRange succOperands)=0
Print the successor and its operands.
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printGenericOp(Operation *op, bool printOpName=true)=0
Print the entire operation with the default generic assembly form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
static OpBuilder atBlockEnd(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still insid...
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Location getLoc()
The source location the operation was defined or derived from.
A trait used to provide symbol table functionalities to a region operation.
A trait to mark ops that can be enclosed/wrapped in a SpecConstantOperation op.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
type_range getType() const
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
AttrClass getAttrOfType(StringAttr name)
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class represents success/failure for parsing-like operations that find it important to chain tog...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
void push_back(Block *block)
This class models how operands are forwarded to block arguments in control flow.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Dialect & getDialect() const
Get the dialect this type is registered to.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
Type getElementType() const
unsigned getNumElements() const
unsigned getNumElements() const
Return the number of elements of the type.
unsigned getRows() const
return the number of rows of the matrix.
unsigned getColumns() const
return the number of columns of the matrix.
Scope getScope() const
Return the scope of the cooperative matrix.
Type getElementType() const
ImageArrayedInfo getArrayedInfo() const
ImageSamplerUseInfo getSamplerUseInfo() const
ImageSamplingInfo getSamplingInfo() const
Scope getScope() const
Return the scope of the joint matrix.
unsigned getColumns() const
return the number of columns of the matrix.
unsigned getRows() const
return the number of rows of the matrix.
Type getElementType() const
unsigned getNumColumns() const
Returns the number of columns.
Type getElementType() const
Returns the elements' type (i.e, single element type).
unsigned getNumRows() const
Returns the number of rows.
Type getPointeeType() const
StorageClass getStorageClass() const
static PointerType get(Type pointeeType, StorageClass storageClass)
Type getImageType() const
unsigned getNumElements() const
Type getElementType(unsigned) const
An attribute that specifies the SPIR-V (version, capabilities, extensions) triple.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the argument at 'index'.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
ParseResult parseFunctionSignature(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
Type getFunctionType(Builder &builder, ArrayRef< OpAsmParser::Argument > argAttrs, ArrayRef< Type > resultTypes)
Get a function type corresponding to an array of arguments (which have types) and a set of result typ...
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
AddressingModel getAddressingModel(TargetEnvAttr targetAttr, bool use64bitAddress)
Returns addressing model selected based on target environment.
FailureOr< ExecutionModel > getExecutionModel(TargetEnvAttr targetAttr)
Returns execution model selected based on target environment.
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
TargetEnvAttr getDefaultTargetEnv(MLIRContext *context)
Returns the default target environment: SPIR-V 1.0 with Shader capability and no extra extensions.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
static ParseResult parseEnumKeywordAttr(EnumClass &value, ParserType &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next keyword in parser as an enumerant of the given EnumClass.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
A callable is either a symbol, or an SSA value, that is referenced by a call-like operation.
This class represents an efficient way to signal success or failure.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addSuccessors(Block *successor)
void addTypes(ArrayRef< Type > newTypes)
SmallVector< Type, 4 > types
Types of the results of this operation.